In this notebook we'll implement a single self-attention head.

In [10]:
import torch
from pyxtend import struct
from torch import nn

In [3]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f88a84d62b0>

In [4]:
B,T,C = 4,8,2 # batch, time, channel

In [5]:
x = torch.randn(B,T,C)

In [6]:
struct(x)

{'Tensor': ['torch.float32, shape=(4, 8, 2)']}

Each of the 4 sequences in the batch has 8 tokens. Right now these tokens don't "talk" to each other, but we want them to, and that is what self-attention does.
A token shouldn't communicate with future tokens, though.

How does self-attention do this? Every position emits two vectors, a query and a key:
* Query vector: what am I looking for?
* Key vector: what do I contain?

Take the dot product of a token's query with the other tokens' keys to get the affinity between them. A large dot product means that key matches what the query is looking for.
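A minimal sketch of the affinity computation for a single sequence. The sizes (T=4, head_size=3) and the random query/key vectors are made up for illustration; in the notebook, q and k come from linear layers.

```python
import torch

torch.manual_seed(0)
T, head_size = 4, 3  # illustrative sizes, not the notebook's

# One query per position ("what am I looking for")
# and one key per position ("what do I contain").
q = torch.randn(T, head_size)
k = torch.randn(T, head_size)

# Affinity of position t (asking) with position s (being asked): one dot product.
t, s = 2, 1
affinity_ts = q[t] @ k[s]  # scalar tensor

# All affinities at once: row t holds q[t] dotted with every key.
affinities = q @ k.T  # shape (T, T)
print(affinities.shape)  # torch.Size([4, 4])
print(torch.allclose(affinities[t, s], affinity_ts, atol=1e-6))  # True
```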

In [7]:
C

2

In [8]:
head_size = 16

In [25]:
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)


In [26]:
struct(key)

'unsupported'

Remember, this linear layer is just a matrix of weights, of shape (head_size, C) = (16, 2). A linear layer can also have a bias vector, but not in this case because we passed `bias=False`.
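A quick sketch to confirm that, assuming the same sizes as above: without a bias, applying the layer is the same as multiplying by the transposed weight matrix. The matrix multiply acts on the last dimension, so it is applied independently at every (batch, time) position.

```python
import torch
from torch import nn

torch.manual_seed(42)
B, T, C, head_size = 4, 8, 2, 16

x = torch.randn(B, T, C)
key = nn.Linear(C, head_size, bias=False)

# nn.Linear stores its weight as (out_features, in_features) = (head_size, C)
print(key.weight.shape)  # torch.Size([16, 2])

# No bias, so key(x) is just x @ W^T applied over the last dimension.
k_manual = x @ key.weight.T
k_layer = key(x)
print(k_layer.shape)                      # torch.Size([4, 8, 16])
print(torch.allclose(k_manual, k_layer))  # True
```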

In [30]:
# Print the weight matrix A (nn.Linear computes y = x @ A.T), shape (head_size, C)
print("Weights:")
print(key.weight)

Weights:
Parameter containing:
tensor([[ 0.4520,  0.6077],
        [-0.0700, -0.1583],
        [ 0.0103, -0.0422],
        [ 0.1700,  0.1982],
        [-0.6422, -0.2609],
        [ 0.5955,  0.2755],
        [-0.0352, -0.4263],
        [-0.4326, -0.6334],
        [-0.2305,  0.2388],
        [ 0.4509,  0.3265],
        [-0.6250, -0.4252],
        [-0.1116,  0.6840],
        [ 0.1023, -0.1831],
        [ 0.2925, -0.2693],
        [-0.4577,  0.5161],
        [-0.3215, -0.1418]], requires_grad=True)


In [13]:
k = key(x)
q = query(x)

In [15]:
struct(k) # B,T,head_size

{'Tensor': ['torch.float32, shape=(4, 8, 16)']}

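`wei` holds the affinities between tokens, i.e. the attention weights. Eventually each token should only look back at itself and earlier tokens, so only the lower triangle of the T-by-T weights, roughly half of them, will be used.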

In [16]:
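# Same index t for q and k: each query is dotted only with its own key (the diagonal), shape (B, T)
wei = torch.einsum('bth,bth->bt', q, k) / head_size**0.5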

In [17]:
wei

tensor([[-0.2144, -0.0022,  0.0514, -0.2483,  0.0195, -0.2520, -0.0301,  0.1136],
        [ 0.3697,  0.0478,  0.0955, -0.4000, -0.2236, -0.2507,  0.0080, -0.0254],
        [-0.0311, -0.2036,  0.0177, -0.0055,  0.4563, -0.0244,  0.2212,  0.0667],
        [-0.0353, -0.0314, -0.0075, -0.0660,  0.2487, -0.0346, -0.2832,  0.2605]],
       grad_fn=<DivBackward0>)
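A sketch comparing the two einsum patterns, with random q and k of the notebook's shapes. Giving the key its own time index s produces every query-key pair, which matches the batched matmul `q @ k.transpose(-1, -2)`; the `'bth,bth->bt'` version is only that matrix's diagonal.

```python
import torch

torch.manual_seed(0)
B, T, hs = 4, 8, 16
q = torch.randn(B, T, hs)
k = torch.randn(B, T, hs)

# Same t on both sides: each query meets only its own key.
diag = torch.einsum('bth,bth->bt', q, k)  # (B, T)

# Separate index s for the key position: every (t, s) pair.
full = torch.einsum('bth,bsh->bts', q, k)  # (B, T, T)

# The full version equals the batched matmul ...
print(torch.allclose(full, q @ k.transpose(-1, -2), atol=1e-5))  # True
# ... and the original einsum is just its diagonal.
print(torch.allclose(diag, torch.diagonal(full, dim1=-2, dim2=-1), atol=1e-5))  # True
```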

In [31]:
wei.var()

tensor(0.7492, grad_fn=<VarBackward0>)

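Without scaling, `wei` gets too big. If q and k have independent, unit-variance entries, each dot product is a sum of head_size products, so its variance is about head_size. Dividing by head_size**0.5 (the standard deviation) brings the variance back to about 1. This matters because softmax turns large values into a nearly one-hot distribution, so each token would end up attending to just one other token.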
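A quick check of that argument, assuming q and k are drawn directly with `torch.randn` (independent, unit variance) rather than produced by the notebook's linear layers. The second part uses arbitrary logits to show how softmax sharpens as its inputs grow.

```python
import torch

torch.manual_seed(0)
B, T = 4, 8
results = {}
for head_size in (16, 64, 256):
    # Assumption: unit-variance queries and keys.
    q = torch.randn(B, T, head_size)
    k = torch.randn(B, T, head_size)
    raw = q @ k.transpose(-1, -2)   # each entry sums head_size products
    scaled = raw / head_size**0.5   # divide by the standard deviation
    results[head_size] = (raw.var().item(), scaled.var().item())
    print(f"head_size={head_size:3d}  raw var={results[head_size][0]:7.1f}  "
          f"scaled var={results[head_size][1]:.2f}")

# Why it matters: softmax of large logits collapses toward one-hot.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = torch.softmax(logits, dim=-1)     # weights spread out
peaky = torch.softmax(logits * 8, dim=-1)   # most weight on index 4
print(diffuse)
print(peaky)
```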

In [18]:
struct(wei)

{'Tensor': ['torch.float32, shape=(4, 8)']}

In [19]:
head_size**0.5

4.0

In [21]:
# Full affinity matrix: (B, T, hs) @ (B, hs, T) -> (B, T, T).
# Left unscaled here; the scaled version divides by head_size**0.5 to keep the variance near 1.
wei = q @ k.transpose(-1,-2)

In [22]:
wei

tensor([[[-0.8577,  1.8921,  1.1375,  1.3057, -1.4903,  1.0758,  0.3221,
          -0.7678],
         [ 3.8746, -0.0090,  0.2215, -1.1078, -0.0636, -1.5045, -1.4620,
          -0.6648],
         [ 2.3820,  0.1053,  0.2058, -0.6189, -0.1273, -0.8814, -0.8989,
          -0.4624],
         [ 2.4181, -0.5436, -0.1994, -0.9932,  0.3885, -1.1503, -0.9120,
          -0.1545],
         [-3.0683, -0.0274, -0.1971,  0.8579,  0.0779,  1.1778,  1.1578,
           0.5432],
         [ 1.8550, -0.7365, -0.3536, -0.9412,  0.5523, -1.0080, -0.6993,
           0.0362],
         [ 0.3204, -0.7140, -0.4294, -0.4918,  0.5624, -0.4047, -0.1203,
           0.2903],
         [-1.7276, -0.3224, -0.3037,  0.3108,  0.2882,  0.5425,  0.6522,
           0.4545]],

        [[ 1.4788, -0.5150, -0.8550,  0.4193,  0.9079,  0.3104, -0.2119,
          -0.3699],
         [-0.7511,  0.1914,  0.2540, -0.5285, -0.7161, -0.4078,  0.1041,
           0.0380],
         [-1.4612,  0.3286,  0.3820, -1.2245, -1.5517, -0.9490,  0.2

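But we don't want every token to talk to every other token: a token must not see the tokens that come after it. So we need some masking, blocking out the upper triangle of `wei` (the future positions) before the softmax.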
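A minimal sketch of the masking step, using the common `torch.tril` plus `masked_fill` approach followed by a softmax. A random tensor stands in for the scaled `q @ k` affinities so the snippet runs on its own.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T = 4, 8
wei = torch.randn(B, T, T)  # stand-in for the scaled q @ k^T affinities

# Lower-triangular mask: row t may attend to columns 0..t only.
tril = torch.tril(torch.ones(T, T))
# Future positions get -inf, so softmax gives them exactly zero weight.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row now sums to 1 over the allowed positions

print(wei[0])  # upper triangle is all zeros; the first row is [1, 0, 0, ...]
```

Using -inf rather than 0 before the softmax matters: a 0 would still get a positive weight after exponentiation, while exp(-inf) is exactly 0.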

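Masking is what makes this a decoder block: because tokens can't see the future, the model can be used to generate (decode) text one token at a time. An encoder block uses the same attention without the mask, so all tokens can talk to each other. So "decoder-only" (as in GPT) does mean every attention block is masked; it also means there is no encoder and therefore no cross-attention.
Predicting each token only from the tokens before it is called autoregressive.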
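A small check of the autoregressive property. `causal_attention` is a hypothetical helper name, and `v` stands in for the values, the per-token vectors that the attention weights average together. Changing only the last token's key and value leaves the outputs at every earlier position untouched.

```python
import torch
from torch.nn import functional as F

def causal_attention(q, k, v):
    """Scaled, causally masked attention over (B, T, hs) tensors (a sketch)."""
    T, hs = q.shape[-2], q.shape[-1]
    wei = q @ k.transpose(-1, -2) / hs**0.5
    tril = torch.tril(torch.ones(T, T))
    wei = wei.masked_fill(tril == 0, float('-inf'))
    return F.softmax(wei, dim=-1) @ v

torch.manual_seed(0)
B, T, hs = 2, 8, 16
q, k, v = (torch.randn(B, T, hs) for _ in range(3))
out = causal_attention(q, k, v)

# Change only the last position's key and value ...
k2, v2 = k.clone(), v.clone()
k2[:, -1] = torch.randn(B, hs)
v2[:, -1] = torch.randn(B, hs)
out2 = causal_attention(q, k2, v2)

# ... positions 0..T-2 are unchanged; only the last position's output moves.
print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True
print(torch.allclose(out[:, -1], out2[:, -1]))    # False
```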

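The same source `x` produces the keys, queries, and values, which is why this is called self-attention. Keys and values could also come from a different source (for example, an encoder's output) while the queries still come from `x`; that is called cross-attention.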
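Putting the pieces together: a sketch of one complete head, including the value projection that this notebook hasn't built yet. The class name `Head`, the `block_size`, `causal`, and `context` arguments, and applying the causal mask only when there is no separate context are all assumptions for illustration. Passing `context` switches keys and values to that source (cross-attention); the context can have a different length (5 here) from `x`.

```python
import torch
from torch import nn
from torch.nn import functional as F

class Head(nn.Module):
    """One attention head (a sketch). Queries come from x; keys and values come
    from `context` if given (cross-attention), otherwise from x (self-attention)."""

    def __init__(self, C, head_size, block_size, causal=True):
        super().__init__()
        self.key = nn.Linear(C, head_size, bias=False)
        self.query = nn.Linear(C, head_size, bias=False)
        self.value = nn.Linear(C, head_size, bias=False)
        self.causal = causal
        # A buffer is saved with the module but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x, context=None):
        src = x if context is None else context
        q = self.query(x)    # (B, T, hs)
        k = self.key(src)    # (B, S, hs)
        v = self.value(src)  # (B, S, hs)
        wei = q @ k.transpose(-1, -2) / k.shape[-1]**0.5  # (B, T, S)
        if self.causal and context is None:
            T = x.shape[1]
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v       # (B, T, hs): weighted average of values

torch.manual_seed(42)
B, T, C, head_size = 4, 8, 2, 16
x = torch.randn(B, T, C)
head = Head(C, head_size, block_size=T)

out_self = head(x)                                 # self-attention
out_cross = head(x, context=torch.randn(B, 5, C))  # keys/values from another source
print(out_self.shape, out_cross.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8, 16])
```

Registering `tril` as a buffer means it moves with the module (for example with `.to(device)`) without being updated by the optimizer.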

TODO: show the formula from the paper here, along with the diagram.
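For reference, the scaled dot-product attention formula from *Attention Is All You Need* (presumably the paper meant here), written with this notebook's names: Q = q, K = k, V = the values, and d_k = head_size. The second line is the masked (decoder) variant, where a mask M adds -inf to the future positions before the softmax, matching the masking step described above.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

\mathrm{MaskedAttention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ts} = \begin{cases} 0 & s \le t \\ -\infty & s > t \end{cases}
```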