In [1]:
import requests

# Tiny Shakespeare corpus: the plain-text dataset used throughout this notebook.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# The timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error (404, 500, ...) into an exception instead
# of silently saving the error page as the dataset.
response = requests.get(url, timeout=30)
response.raise_for_status()

with open("input.txt", "w", encoding="utf-8") as f:
    f.write(response.text)

print("Saved as input.txt ✅")

Saved as input.txt ✅


In [2]:
# Load the downloaded corpus back from disk as one big string so we can inspect it.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [3]:
# Report the size of the corpus, measured in characters (len of the string read above).
print(" Length of dataset in characters:", len(text))

 Length of dataset in characters: 1115394


In [4]:
# Eyeball the first 1000 characters to verify that the file was read correctly.
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# Character-level vocabulary: every distinct character of the corpus, as a sorted list.
chars = sorted(set(text))
vocab_size = len(chars)

print(chars)
print(" All the unique characters in the dataset:", ''.join(chars))
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
 All the unique characters in the dataset: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# Character-level tokenizer: a character's id is its position in the sorted vocabulary `chars`.
itos = dict(enumerate(chars))              # id -> character
stoi = {ch: i for i, ch in itos.items()}   # character -> id


def encode(s):
    """Encoder: take a string and return the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids and return the string they spell."""
    return ''.join(itos[i] for i in l)


print("Encoding 'hello':", encode("hello"))
print(decode(encode("hello")))  # should return 'hello'


Encoding 'hello': [46, 43, 50, 50, 53]
hello


import tiktoken

# For comparison with the character-level tokenizer: the encoding tiktoken exposes as "gpt2".
enc = tiktoken.get_encoding("gpt2")

# Encode once and reuse the ids for every round-trip below.
hello_tokens = enc.encode("hello")
print("Encoding 'hello' with tiktoken:", hello_tokens)
print("Decoding tiktoken encoding:", enc.decode(hello_tokens))  # should return 'hello'

# number of tokens in the vocabulary
print("Number of tokens in the vocabulary:", enc.n_vocab)

# Last expression of the cell, so the notebook displays the round-tripped string.
enc.decode(hello_tokens)  # should return 'hello'

In [7]:
# now let's encode the entire text dataset and store it in a torch.Tensor
import torch

# One integer id per character, stored as int64 ("long").
data = torch.tensor(encode(text), dtype=torch.long)

# The label promises both shape and datatype, so print both (previously only the shape was shown).
print("Encoded data shape and datatype:", data.shape, data.dtype)
# the 1000 characters we looked at earlier will look to the GPT model like this
print("First 1000 characters of encoded data:", data[:1000])
# convert the first 1000 integers back to characters to verify that encode/decode round-trips
print("Decoded first 1000 characters:", decode(data[:1000].tolist()))

Encoded data shape and datatype: torch.Size([1115394])
First 1000 characters of encoded data: tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52

In [8]:
# Hold out the tail of the corpus for validation: the first 90% trains the model, the
# last 10% lets us check whether it generalises instead of memorising the training text.
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]

print("Train data shape:", train_data.shape)
print("Validation data shape:", val_data.shape)

Train data shape: torch.Size([1003854])
Validation data shape: torch.Size([111540])


In [9]:
block_size = 8 # context length: the maximum number of characters the model sees as input at once
train_data[:block_size+1]  # the first block_size + 1 = 9 characters of the training data
# A chunk of 9 characters packs 8 training examples: for every t from 1 to 8, the first t
# characters are the input (context) and character t+1 is the target.
# For example, if the chunk spelled "hello wor", the examples would be
# "h" -> "e", "he" -> "l", ..., and finally "hello wo" -> "r".



tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
# Spell out the (context -> target) examples contained in the first block_size + 1 characters.
x = train_data[:block_size]      # inputs: the first 8 characters
y = train_data[1:block_size+1]   # targets: the same window shifted one position to the right
for t, target in enumerate(y):
    context = x[:t+1]  # everything up to and including position t
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [11]:
# Batch dimensions
torch.manual_seed(1337)  # for reproducibility
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?
# Each sequence will be of length block_size, and we will have batch_size sequences in parallel.
# The sequences are cut from the data at independently drawn random offsets (see get_batch below),
# so they may overlap and are not consecutive chunks of the text.

def get_batch(split):
    """Sample a random mini-batch of input/target sequences.

    `split` selects the source: 'train' reads from train_data, anything else
    from val_data. Returns a pair (x, y) of stacked tensors with one row per
    sampled sequence; every row holds block_size consecutive tokens, and y is
    the same window as x shifted one position to the right.
    Relies on the notebook-level batch_size and block_size.
    """
    source = train_data if split == 'train' else val_data
    # Upper bound leaves room for block_size inputs plus the one extra target token.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start+block_size])
        targets.append(source[start+1:start+block_size+1])
    return torch.stack(inputs), torch.stack(targets)

# Example usage: draw one training batch and inspect it
xb, yb = get_batch('train')
print("Input batch shape:", xb.shape)  # should be (batch_size, block_size)
print("Target batch shape:", yb.shape)  # should be (batch_size, block_size)
print("Input batch:", xb)  # the input sequences
print("Target batch:", yb)  # the target sequences, shifted by one character

print('--'*50)

# Every row of the batch holds block_size training examples, one per prefix of the sequence.
for b, (inputs, targets) in enumerate(zip(xb, yb)):
    for t, target in enumerate(targets):
        context = inputs[:t+1]  # input sequence up to and including position t
        print(f"Batch {b}, when input is {context.tolist()} the target: {target.item()}")
        

Input batch shape: torch.Size([4, 8])
Target batch shape: torch.Size([4, 8])
Input batch: tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
Target batch: tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----------------------------------------------------------------------------------------------------
Batch 0, when input is [24] the target: 43
Batch 0, when input is [24, 43] the target: 58
Batch 0, when input is [24, 43, 58] the target: 5
Batch 0, when input is [24, 43, 58, 5] the target: 57
Batch 0, when input is [24, 43, 58, 5, 57] the target: 1
Batch 0, when input is [24, 43, 58, 5, 57, 1] the target: 46
Batch 0, when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
Batch 0, when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
Batch 1, wh

In [13]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)  # fixed seed so the model created and sampled from below behaves the same on every run

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The whole model is one (vocab_size x vocab_size) embedding table whose row
    for token i holds the logits over the token that follows i.
    """

    def __init__(self, vocab_size):
        """Create the lookup table; vocab_size is the number of distinct tokens."""
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a (B, T) batch of token ids.

        Without targets, logits keep the shape (B, T, C) and loss is None.
        With targets, logits are returned flattened to (B*T, C) and loss is the
        cross-entropy between them and the flattened targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            return logits, None

        # cross_entropy wants one row of class scores per example, so fold batch and time together
        batch, time, channels = logits.shape
        logits = logits.view(batch * time, channels)
        loss = F.cross_entropy(logits, targets.view(batch * time))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend the (B, T) context idx by max_new_tokens sampled tokens and return it."""
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)  # no targets during generation, so the loss is None
            last_step = logits[:, -1, :]  # only the final position predicts the next token: (B, C)
            probs = F.softmax(last_step, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # one sample per sequence: (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)  # forward pass
print(logits.shape)  # (batch_size * block_size, vocab_size): forward() flattens the logits whenever targets are passed
print(loss.item())  # print the loss value
# generation below starts from a single token with id 0, which is the newline character in this vocabulary
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))  # generate 100 new tokens and decode them to a string


torch.Size([32, 65])
4.878634929656982

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


The generated output is garbage because the model has not been trained yet (its weights are still random). It also ignores all history beyond the last character, which we will address later.

In [14]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)  # AdamW over all parameters of the model, with a learning rate of 1e-3

In [24]:
batch_size = 32  # how many independent sequences will we process in parallel?
num_steps = 10000  # number of optimisation steps per run of this cell
for steps in range(num_steps):
    # sample a fresh batch of data and evaluate the loss on it
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # standard update: clear old gradients, backpropagate, then step the optimizer
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())  # loss of the final batch only, so a noisy estimate

2.311746120452881


In [26]:
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=1000)[0].tolist()))  # generate 1000 new tokens and decode them to a string


Inghibey?

Wher t Moman llou f thar werd felld:
CIns,
Shtonotre de.
Nureendyoous:
PULAng, t? yothe'sefofad y hapareove.
CENI t f s fy d l the, ys.

TI,
An ceatmisce jesshe ee.
CH: t thed tewenoncrgis an mp?
Bunevest-
Bu oucy,
Hen latas t
LYoflin t o ldgur wey;
Air? LIO, nderis, wemysoncer INI wir oul thalivoccouy d,


MI myomave I nd-then, d whanghe ther am m.
Toupove y
Wavie athe he indr:

and
Pracor jer leathewomicefownge 'shyouclen he t, feret, bind ig qu;
Coy IS:
Yolllet juns w t y
Fonowangonesthed cealou none sur ply,
Nombu al?

And nom ma,
SBuburou'l bital ts theendist, wo Be; alllo.
sadacew m yos o f gelllooter, wbene d y s. mounthigrave he hal irbe, t?
HONu as, as os?
Thirsin porey d
TI age
I s yeryouresthovo atir s,

Pi's t f tr, du?
EDUSobeeathichemeg g y toul IULAlmithu ke, w bhe ne uen thery RENToon, IZEEDUCI cend ot l me's angice;
Thewareratulyendutonde 's thie,
fr arnd ic-mybem t m'dsent ee, blouged fere fearau se fofouriey'd!
LANEThellld ogo withend tequ
Y ldill, totooo

Here we can see that, after training for roughly a million iterations, the output sort of looks like Shakespeare, but it is definitely not there yet.
We are only looking at the last character to predict the next one. Next, we will let the characters start to talk to each other and figure out the context, so that they can better predict what comes next.

# The mathematical trick in self-attention

The tokens shouldn't talk to future tokens, only to previous ones. Information should only flow from the previous context to the current timestep, which is then used to predict the future.

In [29]:
# consider the following toy example

torch.manual_seed(1337)
B = 4  # batch size
T = 8  # sequence length (time steps)
C = 2  # channels: number of features carried by each token
x = torch.randn((B, T, C))  # random input tensor
x.shape  # should be (4, 8, 2)


torch.Size([4, 8, 2])

# Version 1

In [41]:
# Goal: xbow[b, t] = mean_{i<=t} x[b, i]  (the mean of all previous time steps, including the current one).
# A plain double loop does the job, but it is the least efficient way to compute it.
xbow = torch.zeros((B, T, C))  # "bag of words": running averages, same shape as x
for b, sequence in enumerate(x):  # for each sequence in the batch
    for t in range(T):  # for each time step in the sequence
        xprev = sequence[:t+1]  # rows 0..t of this sequence: (t+1, C)
        xbow[b, t] = xprev.mean(dim=0)  # average over the time dimension
print(x[0],'\n', xbow[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]]) 
 tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


# Version 2

In [None]:
# Toy demonstration: a running average can be written as one matrix multiplication.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))  # lower-triangular matrix of ones
a = a / a.sum(dim=1, keepdim=True)  # normalize the rows so that each row sums to 1
b = torch.randint(0, 10, (3, 2)).float()  # random integers in [0, 10), stored as floats, shape (3, 2)
c = a @ b  # matrix multiplication
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)  # should be a matrix of shape (3, 2)
# Row i of c is a weighted sum of the rows of b with the weights in row i of a. Because a is
# lower triangular with rows summing to 1, row i of c is the average of rows 0..i of b.

a:
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b:
 tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c:
 tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [None]:
# The same running average for every sequence at once, via a single matrix multiplication.
weights = torch.tril(torch.ones(T, T))  # lower triangular matrix of shape (T, T)
weights = weights / weights.sum(dim=1, keepdim=True)  # normalize the rows so each sums to 1
# (T, T) @ (B, T, C): the (T, T) weights are broadcast over the batch dimension -> (B, T, C)
xbow2 = weights @ x
torch.allclose(xbow[0], xbow2[0])


True

# Version 3 (Use softmax)

In [55]:
# Build the mask in this cell: on a fresh kernel nothing earlier in the notebook defines
# `tril`, so using it here without this line raises a NameError under Restart & Run All.
tril = torch.tril(torch.ones(T, T))  # lower triangular matrix of ones, shape (T, T)
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf')) # for all tril elements that are zero, make them -inf
weights

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
# Build the mask in this cell so it does not depend on `tril` leaking in from a cell
# that is defined further down the notebook.
tril = torch.tril(torch.ones(T, T))  # lower triangular matrix of ones, shape (T, T)
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf')) # for all tril elements that are zero, make them -inf
# softmax maps every -inf entry to 0 and spreads the weight evenly over the remaining zeros in each row
weights = F.softmax(weights, dim=-1)
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [49]:
# Version 3 end to end: mask the future with -inf, softmax each row, then aggregate.
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))  # future positions get -inf
weights = F.softmax(weights, dim=-1)  # every row becomes a uniform average over positions <= t
xbow3 = weights @ x
torch.allclose(xbow, xbow3)

True

🔤 vocab_size

The total number of unique tokens (characters or words) in your dataset.

You use vocab_size:
- In nn.Embedding(vocab_size, ...): Because each token index must map to a unique embedding vector.
- In nn.Linear(..., vocab_size): Because your model must output a probability distribution over all possible next tokens.

👉 Think:

“How many different outputs do I have to predict from?” — That’s your vocab_size.

⸻

📐 n_embd (embedding dimension)

How many features you want for each token or position embedding.

You use n_embd:
- In nn.Embedding(vocab_size, n_embd): to embed each token as a vector of size n_embd.
- In nn.Embedding(block_size, n_embd): to embed each position as a vector of size n_embd.
- In nn.Linear(n_embd, vocab_size): to transform the internal representation into logits over the vocabulary.

👉 Think:

“How big is the vector I use to represent each token and position?” — That’s your n_embd.

⸻

🔢 block_size

The maximum context length the model can see at once (how many tokens in one input sequence).

You use block_size:
- In nn.Embedding(block_size, n_embd): Because each position from 0 to block_size - 1 needs its own embedding.
- For slicing input text into fixed-size chunks of length block_size.

👉 Think:

“How far back can the model look when predicting the next token?” — That’s your block_size.

# Version 4: self-attention!!!!!


In [69]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32   # B = batch size, T = sequence length (block size), C = embedding dim (channels)
x = torch.randn(B, T, C)

# Let's see a single head perform self-attention.
head_size = 16       # number of output channels of each attention head

# Three bias-free linear projections, applied to every token independently (no communication yet):
#   query - what am I looking for?
#   key   - what do I contain?
#   value - what do I hand over when another token attends to me (aggregated instead of the raw x)
# They are created in this order so that the seeded initialisation stays the same.
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False)

k = key(x)    # (B, T, 16)
q = query(x)  # (B, T, 16)
v = value(x)  # (B, T, 16)

# Causal mask: position i may only look at positions j <= i.
tril = torch.tril(torch.ones(T, T))  # lower triangular matrix of shape (T, T)

# Every query is dot-multiplied with every key: (B, T, 16) @ (B, 16, T) -> (B, T, T).
# For each sequence this gives a T x T matrix of affinities: how much token i should attend to token j.
weights = q @ k.transpose(-2, -1)

# Unlike the earlier versions, the scores are data dependent rather than all zeros. Future
# positions are set to -inf so the model cannot "cheat" by looking ahead, and the softmax
# turns each row weights[b][i] into a probability distribution over the tokens j <= i.
weights = F.softmax(weights.masked_fill(tril == 0, float('-inf')), dim=-1)

# Each token receives a weighted sum of the value vectors (not of the raw x) of the tokens it attends to.
out = weights @ v
out.shape


torch.Size([4, 8, 16])

In [70]:
weights[0]  # attention weights of the first sequence in the batch: row i is token i's distribution over tokens j <= i

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:

- Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across the batch dimension is of course processed completely independently; examples never "talk" to each other.
- In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
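- "Scaled" attention additionally multiplies the attention scores (`weights`, before the softmax) by 1/sqrt(head_size), i.e. divides them by sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, the scores have unit variance too, and the softmax stays diffuse and does not saturate too much. Illustration below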