In [64]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [65]:
print(f'Length of dataset in characters: {len(text)}')

Length of dataset in characters: 1115394


In [66]:
print(text[:1000])  # print the first 1000 characters to check the data

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



Unique characters in the dataset:

In [67]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size) # integers that range from0 to 64


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


Tokenizer (from characters to integers (encoder) and back (decoder)):

In [68]:
stoi = { ch:i for i,ch in enumerate(chars) } # mapping in a dict each char to its index in the unique char list
print(stoi)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}


In [69]:
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # from string to list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # from list of integers to string

print(encode("hii there"))
print(decode(encode("hii there"))) # encode*decode is just null

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [70]:
import torch
data = torch.tensor(encode(text), dtype=torch.long) # we encode all the text in shakespeare and convert it into a tensor
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

Train validation split:

In [71]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
len(train_data), len(val_data) # 1M chars training

(1003854, 111540)

In [72]:
block_size = 8 # context length (total of chars going at the same time through the transformer)
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

We will train it on block_size+1 because we want the next character as target.

We will train on all sequences of length 1 up to block_size in the text, not because we don't have the complete sequence, but because we want the model to learn to predict the next character given any context length from 1 up to block_size.

In [73]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [74]:
torch.manual_seed(1337)
batch_size = 4 # sequences we will process in parallel
block_size = 8 # context window

def get_batch(split, verbose=False): # split is train or val
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    if verbose:
        print(f'Starting token positions for each sequence of the {batch_size} ran in parallel: {ix}')
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train', verbose=True)
print(xb)
print(xb.shape) # input size of the transformer
print(yb) # same as xb but with an offset=1
print(yb.shape)

Starting token positions for each sequence of the 4 ran in parallel: tensor([ 76049, 234249, 934904, 560986])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
torch.Size([4, 8])


In [75]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1] # [row, col] --> [n_batch, input_sequence_len]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53, 56, 1, 58, 46] the target: 39
when input is [44, 53, 56, 1, 58, 46, 39] the target: 58
when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1
when input is [52] the target: 58
when input is [52, 58] the target: 1
when input is [52, 58, 1] the target: 58
when input is [52, 58, 1, 58] the target: 46
when input is [52, 58, 1, 58, 46] the target: 39
when input is [52, 58, 1, 58, 46, 39] the t

In [76]:
print(xb) # input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


In [77]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (Batch=4,Time=8,Channel=vocab_size=65)

        if targets is None:
            loss = None
        else:
            # pytorch requires (B,C,T) and not (B,T,C) for cross_entropy
            B,T,C = logits.shape
            logits = logits.view(B*T, C) # view is like reshape
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) # logits are predictions

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get predictions
            logits, loss = self(idx) # self makes it run the forward function
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B,C), so the prob for each token to be the next, and that for each seq
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B,C)
            # sample from the distribution: it doesn't take the highest prob, but follows the probs we just predicted
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # append sampled index to the running sequence (adds predicted next token to the x tensor)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
            # and goes again to the top with now one token more on the x seq to predict another, until max_new_tokens
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(logits)
print(loss) # should be about -ln(1/65)=4.17 , but with some entropy it has it rounds up to 4.87

torch.Size([32, 65])
tensor([[-1.5101, -0.0948,  1.0927,  ..., -0.6126, -0.6597,  0.7624],
        [ 0.3323, -0.0872, -0.7470,  ..., -0.6716, -0.9572, -0.9594],
        [ 0.2475, -0.6349, -1.2909,  ...,  1.3064, -0.2256, -1.8305],
        ...,
        [-2.1910, -0.7574,  1.9656,  ..., -0.3580,  0.8585, -0.6161],
        [ 0.5978, -0.0514, -0.0646,  ..., -1.4649, -2.0555,  1.8275],
        [-0.6787,  0.8662, -1.6433,  ...,  2.3671, -0.7775, -0.2586]],
       grad_fn=<ViewBackward0>)
tensor(4.8786, grad_fn=<NllLossBackward0>)


In [78]:
# we choose zero to be the first char; reminder that zero is a \n character
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), # batch=1, time=1
                        max_new_tokens=100)[0].tolist()))
# generates a totally random sequence


SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ




Now let's train the model:

In [79]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [80]:
batch_size = 32 # instead of the previous 4
for i,steps in enumerate(range(10000)):
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if i%100==0:
        print(loss.item())

4.692410945892334
4.621085166931152
4.549462795257568
4.345612049102783
4.255731582641602
4.214480876922607
4.124096870422363
3.9863951206207275
3.9517807960510254
3.837888717651367
3.7637593746185303
3.6824676990509033
3.533822536468506
3.513597011566162
3.4971799850463867
3.3378093242645264
3.3668527603149414
3.2826080322265625
3.1327052116394043
3.160909414291382
3.2342257499694824
2.997836112976074
3.0942726135253906
2.9780402183532715
2.890953302383423
2.939120292663574
2.8254289627075195
2.921311378479004
2.886559247970581
2.8697657585144043
2.892245292663574
2.7563703060150146
2.6004951000213623
2.627633571624756
2.7147138118743896
2.718296766281128
2.714982748031616
2.606290817260742
2.723784923553467
2.60630464553833
2.703908681869507
2.7407634258270264
2.6153855323791504
2.6925723552703857
2.5623557567596436
2.6690523624420166
2.595306396484375
2.5762505531311035
2.5814590454101562
2.531094789505005
2.515348434448242
2.4976727962493896
2.5762481689453125
2.6668245792388916
2.

In [81]:
# we use the decoder again for a prediction, expanding the predicting tokens to 500
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


lso br. ave aviasurf my, yxMPZI ivee iuedrd whar ksth y h bora s be hese, woweee; the! KI 'de, ulseecherd d o blllando;LUCEO, oraingofof win!
RIfans picspeserer hee tha,
TOFonk? me ain ckntoty ded. bo'llll st ta d:
ELIS me hurf lal y, ma dus pe athouo
BEY:! Indy; by s afreanoo adicererupa anse tecorro llaus a!
OLeneerithesinthengove fal amas trr
TI ar I t, mes, n IUSt my w, fredeeyove
THek' merer, dd
We ntem lud engitheso; cer ize helorowaginte the?
Thak orblyoruldvicee chot, p,
Bealivolde Th li


Until here we have just used the bigram model, so it makes little sense as it has very small context window on previous tokens.

Let's now build a proper transformer model.

## The mathematical trick in self-attention

In [82]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Tokens should only attend to previous tokens, not future ones.

We'll start by calculating the mean of the previous tokens + the current for all channels.

In [83]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bow = bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

In [84]:
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

Okay that's good, but very inefficient. Let's do it with matrix operations using the trick.

In [85]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('c=')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


If we look carefully, we can see that the output at time t is just the sum of the first t rows of the input.

This is because of the magic that happens by using the lower triangular matrix to only select the tokens up to t.

Now let's do the mean instead of the sum.

In [86]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Let's now do this for the real matrix x of shape (B,T,C)

In [87]:
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [88]:
xbow2 = weights @ x # (T, T) @ (B, T, C) --> (B, T, T) @ (B, T, C) --> (B, T, C)
torch.allclose(xbow, xbow2, atol=1e-6) # same result, more efficient

True

In [89]:
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

Another final version of the same, now with Softmax:

In [96]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
print(wei)
wei = F.softmax(wei, dim=-1)
print(wei)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


Softmax first exponents all elements, so -inf to 0 and 0 to 1, then normalizes the rows to sum to 1.

So first step is:
```python
[1,0,0,0,0...],
[1,1,0,0,0...],
[1,1,1,0,0...],
[1,1,1,1,0...],
```

And second step is:
```python
[1,0,0,0,0...],
[0.5,0.5,0,0,0...],
[0.33,0.33,0.33,0,0...],
[0.25,0.25,0.25,0.25,0...],
```

In [101]:
xbow3 = wei @ x
torch.allclose(xbow, xbow3, atol=1e-6)

True

TLDR: you can do weighted aggregation of your past elements by making matrix multiplication of your lower triangular matrix with the elements you want to aggregate.

## Implementing self-attention

Let's make a fourth version with self-attention

In [None]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # we will use a 32-dimension embedding for each token
x = torch.randn(B,T,C)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

out = wei @ x
out.shape

torch.Size([4, 8, 32])

Actually, in attention we have queries (what we are looking for) and keys (what we have).

The multiplication of all queries with all keys gives us a score of how much each key matches each query.

That is equivalent to wei.

In [None]:
# let's see a single Head perform self-attention

torch.manual_seed(1337)
B,T,C = 4,8,32 # we will use a 32-dimension embedding for each token
x = torch.randn(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) * C**-0.5 # we want to transpose the last two dimensions: T and head_size --> (B, T, 16) @ (B, 16, T) --> (B, T, T)     C**-0.5 is explained below
torch.set_printoptions(precision=2, sci_mode=False)
print('Raw connections between keys and queries:\n', wei[0])

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T)) --> not anymore
wei = wei.masked_fill(tril==0, float('-inf'))
print('Only lower triangle, we cant use the tokens after our token:\n', wei[0])
wei = F.softmax(wei, dim=-1)
print('After softmax normalization:\n', wei[0])

v = value(x)
#out = wei @ x
out = wei @ v # we use v instead of x, so it'll be 16-dimensional (head_size) instead of 32
out.shape

Raw connections between keys and queries:
 tensor([[-0.31, -0.23,  0.10,  0.38, -0.19,  0.35,  0.19, -0.08],
        [-0.59, -0.29,  0.02,  0.60, -0.39,  0.18, -0.01,  0.05],
        [-0.18, -0.22,  0.01, -0.07, -0.17, -0.25,  0.01, -0.17],
        [ 0.14, -0.14, -0.06, -0.15, -0.10, -0.21, -0.23, -0.18],
        [-0.22,  0.00, -0.14, -0.23,  0.36,  0.15,  0.07,  0.16],
        [-0.06,  0.43, -0.02, -0.18,  0.59, -0.45,  0.25,  0.22],
        [ 0.19,  0.35, -0.05, -0.06,  0.11,  0.22, -0.10,  0.14],
        [-0.32, -0.07, -0.15,  0.10, -0.14, -0.10,  0.11,  0.11]],
       grad_fn=<SelectBackward0>)
Only lower triangle, we cant use the tokens after our token:
 tensor([[-0.31,  -inf,  -inf,  -inf,  -inf,  -inf,  -inf,  -inf],
        [-0.59, -0.29,  -inf,  -inf,  -inf,  -inf,  -inf,  -inf],
        [-0.18, -0.22,  0.01,  -inf,  -inf,  -inf,  -inf,  -inf],
        [ 0.14, -0.14, -0.06, -0.15,  -inf,  -inf,  -inf,  -inf],
        [-0.22,  0.00, -0.14, -0.23,  0.36,  -inf,  -inf,  -inf],
  

torch.Size([4, 8, 16])

We have one thing left to explain: the division by sqrt(dk) where dk is the head size.

We need to do this to avoid that the variance of the dot products grows too large with the dimension, making the softmax have extremely small gradients.

In [159]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1)

print(k.var())
print(q.var())
print(wei.var())

tensor(1.00)
tensor(1.06)
tensor(17.60)


In [160]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)

tensor([0.03, 0.00, 0.16, 0.00, 0.80])

In [161]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1) * head_size**-0.5

print(k.var())
print(q.var())
print(wei.var())

tensor(1.02)
tensor(0.97)
tensor(0.99)


In [162]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.19, 0.14, 0.24, 0.14, 0.29])

## Multi-head attention

Running multiple attention heads in parallel and concateninating their results.

This allows the model to jointly attend to information from different representation subspaces at different positions (this is, attend to different things/patterns at the same time, see the relationships from different perspectives).

*Implemented on the multihead-attention.py script.