### Causal Attention
Simplifications: single head, no mini-batches, no padding, and only the linear Q/K/V projections (no output projection, MLP, or normalization layers).

In [1]:
import torch
from torch import nn
# Compact tensor display: 2 decimals, wide lines so (T, T) matrices fit on one row each.
torch.set_printoptions(linewidth=200, precision=2)

In [2]:
V = 100  # vocab size
d = 768  # embedding dims (single head, so this is also the query/key dim used for scaling)
T = 5    # sequence length
SEED = 0  # fixed seed so the random layer init and token sampling below are reproducible
# NOTE: outputs recorded in this notebook predate the seed; re-running will show different (but stable) values.
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in the notebook

In [3]:
emb = nn.Embedding(V, d)                        # token embeddings: token id -> d-dim vector
clf = nn.Linear(in_features=d, out_features=V)  # classification head: d-dim vector -> V logits

In [4]:
# Created in q, k, v order (same as before) so the random initialization order is unchanged.
q_proj = nn.Linear(d, d)  # query projection
k_proj = nn.Linear(d, d)  # key projection
v_proj = nn.Linear(d, d)  # value projection

In [5]:
tokens = torch.randint(0, V, (T,))  # T random token ids in [0, V)
tokens, tokens.shape

(tensor([95, 11, 75, 30, 31]), torch.Size([5]))

In [6]:
x = emb(input=tokens)  # (T, d): one d-dim embedding vector per token

In [7]:
q = q_proj(x)  # queries
k = k_proj(x)  # keys
v = v_proj(x)  # values
q.shape

torch.Size([5, 768])

Below: all query–key scores are computed in parallel with a single matrix product, scaled by $1/\sqrt{d}$; the resulting attention matrix is in $\mathbb{R}^{T \times T}$.

In [8]:
scale = torch.tensor(d).sqrt()            # sqrt(d): single head, so the key dim is d
raw_att = torch.matmul(q, k.t()) / scale  # scaled dot-product score for every (query, key) pair
raw_att, raw_att.shape

(tensor([[-0.82, -0.36, -0.15,  0.76, -0.32],
         [ 0.03, -0.23, -0.01,  0.25, -0.73],
         [ 0.43,  0.37, -0.27,  0.20, -0.52],
         [ 0.19, -0.01,  0.19,  0.06, -0.21],
         [-0.04,  0.16, -0.30, -0.12, -0.27]], grad_fn=<DivBackward0>),
 torch.Size([5, 5]))

Below: row $i$ holds the (pre-softmax) scores of query position $i$ against every key position $j$. After the softmax, each score determines how much we take from position $j$ of the input sequence into position $i$ of the output sequence.

E.g. the third row means that we take values from the first three elements of the input sequence but not from the remaining two. This row only prepares the information that is later used to predict the third token of the output sequence, and the causal mask prevents it from looking into the future.

Note that the next-token prediction itself is a separate step; we'll see it at the end, in the [Loss](#Loss) section.

In [9]:
masked_att = raw_att.tril()  # keep the diagonal and below; entries above it (future positions) become 0
masked_att, masked_att.shape

(tensor([[-0.82,  0.00,  0.00,  0.00,  0.00],
         [ 0.03, -0.23,  0.00,  0.00,  0.00],
         [ 0.43,  0.37, -0.27,  0.00,  0.00],
         [ 0.19, -0.01,  0.19,  0.06,  0.00],
         [-0.04,  0.16, -0.30, -0.12, -0.27]], grad_fn=<TrilBackward0>),
 torch.Size([5, 5]))

In [10]:
# Mask by *position*, not by value: the zeros produced by `tril` above are not a reliable
# marker, because a genuine (visible) score can also be exactly 0.0 and would then be
# wrongly set to -inf. Building the mask from positions avoids that, and filling
# out-of-place from `raw_att` keeps this cell idempotent when re-run.
causal_mask = torch.ones_like(raw_att, dtype=torch.bool).tril()  # True = key position j <= query position i
masked_att = raw_att.masked_fill(~causal_mask, float('-inf')); masked_att, masked_att.shape

(tensor([[-0.82,  -inf,  -inf,  -inf,  -inf],
         [ 0.03, -0.23,  -inf,  -inf,  -inf],
         [ 0.43,  0.37, -0.27,  -inf,  -inf],
         [ 0.19, -0.01,  0.19,  0.06,  -inf],
         [-0.04,  0.16, -0.30, -0.12, -0.27]], grad_fn=<MaskedFillBackward0>),
 torch.Size([5, 5]))

In [11]:
masked_att = masked_att.softmax(dim=-1)  # row-wise normalization; exp(-inf) = 0, so masked positions get weight 0
masked_att, masked_att.shape

(tensor([[1.00, 0.00, 0.00, 0.00, 0.00],
         [0.56, 0.44, 0.00, 0.00, 0.00],
         [0.41, 0.39, 0.20, 0.00, 0.00],
         [0.27, 0.22, 0.27, 0.24, 0.00],
         [0.21, 0.26, 0.16, 0.20, 0.17]], grad_fn=<SoftmaxBackward0>),
 torch.Size([5, 5]))

In [12]:
new_v = torch.matmul(masked_att, v)  # each output row is the attention-weighted mix of value rows
new_v, new_v.shape

(tensor([[-0.55,  0.09, -0.22,  ..., -0.33, -0.84,  0.01],
         [-0.56, -0.08, -0.21,  ..., -0.18, -0.67,  0.24],
         [-0.46,  0.01,  0.05,  ..., -0.25, -0.41,  0.30],
         [-0.44,  0.09,  0.06,  ..., -0.24, -0.14,  0.39],
         [-0.48, -0.11, -0.08,  ..., -0.06, -0.19,  0.39]], grad_fn=<MmBackward0>),
 torch.Size([5, 768]))

### Interlude: Sanity check against PyTorch's built-in scaled dot-product attention

In [13]:
attn_mask = torch.ones(T, T, dtype=torch.bool).tril()  # boolean lower-triangular mask
attn_mask

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

In [14]:
# With a boolean attn_mask, True marks the (query, key) pairs that may take part in attention.
nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False, attn_mask=attn_mask)

tensor([[-0.55,  0.09, -0.22,  ..., -0.33, -0.84,  0.01],
        [-0.56, -0.08, -0.21,  ..., -0.18, -0.67,  0.24],
        [-0.46,  0.01,  0.05,  ..., -0.25, -0.41,  0.30],
        [-0.44,  0.09,  0.06,  ..., -0.24, -0.14,  0.39],
        [-0.48, -0.11, -0.08,  ..., -0.06, -0.19,  0.39]], grad_fn=<MmBackward0>)

In [15]:
# Recompute the built-in result instead of reading `_` (the previous cell's output):
# `_` silently changes when cells are re-run or executed out of order.
sdpa_explicit_mask = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
torch.allclose(sdpa_explicit_mask, new_v, atol=1e-5)

True

In [16]:
# No explicit mask here: is_causal=True asks the built-in to apply the causal mask itself.
nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, attn_mask=None)

tensor([[-0.55,  0.09, -0.22,  ..., -0.33, -0.84,  0.01],
        [-0.56, -0.08, -0.21,  ..., -0.18, -0.67,  0.24],
        [-0.46,  0.01,  0.05,  ..., -0.25, -0.41,  0.30],
        [-0.44,  0.09,  0.06,  ..., -0.24, -0.14,  0.39],
        [-0.48, -0.11, -0.08,  ..., -0.06, -0.19,  0.39]], grad_fn=<MmBackward0>)

In [17]:
# Recompute instead of reading `_`, which depends on execution order;
# here the causal mask is applied internally via is_causal=True.
sdpa_is_causal = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
torch.allclose(sdpa_is_causal, new_v, atol=1e-5)

True

### Loss
Predict the next token.

In [18]:
(new_v.shape, clf)  # the attention output's last dim must match the head's in_features

(torch.Size([5, 768]), Linear(in_features=768, out_features=100, bias=True))

In [19]:
# ... some more layers ...
logits = clf(new_v)              # one score per vocabulary entry at each position
pred_tokens = logits.argmax(-1)  # greedy prediction: highest-scoring token id per position
# Note: cross-entropy below does NOT take an argmax; it applies log-softmax to the
# logits and uses the log-probability assigned to each target token.
pred_tokens, logits.shape, pred_tokens.shape

(tensor([ 0, 66, 81, 43, 87]), torch.Size([5, 100]), torch.Size([5]))

In [20]:
# Dummy targets: one class id per position. A real next-token setup would presumably
# use the input tokens shifted left by one instead.
targets = torch.arange(T)
targets

tensor([0, 1, 2, 3, 4])

In [21]:
# ignore_index=-1: positions whose target is -1 (e.g. a padding index) would be excluded from the loss.
loss = nn.functional.cross_entropy(logits, targets, ignore_index=-1)
loss

tensor(4.54, grad_fn=<NllLossBackward0>)