In [243]:
torch.cuda.is_available()

False

In [18]:
# read the contents of the file into a variable called text
with open( 'input.txt', 'r', encoding='utf-8' ) as f:
    text = f.read()

print( "Text length: ", len(text))
print( "----\n" )
print(text[0:100])
print( "----\n" )
print( "text[0:100] == text[:100] =>", text[0:100] == text[:100] )

Text length:  1115394
----

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
----

text[0:100] == text[:100] => True


In [27]:
# create the vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("Vocab: ", ''.join(chars))
print("Vocab size: ", vocab_size)

Vocab:  
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size:  65


In [30]:
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

In [33]:
# encode takes a string and returns a list of integers
encode = lambda s: [ stoi[c] for c in s ]
decode = lambda l: ''.join( itos[i] for i in l )

In [34]:
encode('ball')

[40, 39, 50, 50]

In [35]:
decode([40, 39, 50, 50])

'ball'

In [36]:
decode(encode('this is a sentence'))

'this is a sentence'

In [39]:
# Encode the entire text data set and store it in a tensor
import torch
data = torch.tensor( encode(text), dtype=torch.long )
print(data.shape, data.dtype)

data[:10]

torch.Size([1115394]) torch.int64


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [54]:
decode(data[:62].tolist())

'First Citizen:\nBefore we proceed any further, hear me speak.\n\n'

In [82]:
# split the data up into train / validation
n = int( 0.9 * len(data) )
train_data = data[:n]
validation_data = data[n:]
print( len(train) + len(validation) == len( data ) )

True


In [83]:
block_size = 8
print( data[:block_size] )
print( data[1:block_size+1] )

tensor([18, 47, 56, 57, 58,  1, 15, 47])
tensor([47, 56, 57, 58,  1, 15, 47, 58])


In [84]:
x = data[:block_size]
y = data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print( f"Input: { context } - Target: { target }" )

Input: tensor([18]) - Target: 47
Input: tensor([18, 47]) - Target: 56
Input: tensor([18, 47, 56]) - Target: 57
Input: tensor([18, 47, 56, 57]) - Target: 58
Input: tensor([18, 47, 56, 57, 58]) - Target: 1
Input: tensor([18, 47, 56, 57, 58,  1]) - Target: 15
Input: tensor([18, 47, 56, 57, 58,  1, 15]) - Target: 47
Input: tensor([18, 47, 56, 57, 58,  1, 15, 47]) - Target: 58


In [85]:
print(len(data) - block_size)

1115386


In [86]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

(batch_size,)

(4,)

In [95]:
# Get 4 random block starting points
ix_temp = torch.randint(len(data) - block_size, (batch_size,))
print(ix_temp)
print(ix_temp[0])
print(train_data[0:10])
print(train_data[ix_temp[0]:ix_temp[0]+block_size])
print(train_data[ix_temp[0]+1:ix_temp[0]+block_size+1])



tensor([978324, 638409, 104713,  75569])
tensor(978324)
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])
tensor([12,  0,  0, 28, 39, 45, 43, 10])
tensor([ 0,  0, 28, 39, 45, 43, 10,  0])


In [104]:
def get_batch(split):
    data = train_data if split == 'train' else validation_data
    ix = torch.randint( len(data) - block_size, (batch_size,))
    x = torch.stack( [ data[i:i+block_size] for i in ix ] )
    y = torch.stack( [ data[i+1:i+1+block_size] for i in ix ] )
    return x, y

xb,yb = get_batch('train')
print(x.shape)
print(y.shape)
print(x)
print(y)

torch.Size([4, 8])
torch.Size([4, 8])
tensor([[47, 58,  1, 57, 43, 43, 51, 57],
        [43, 50,  1, 51, 43,  1, 58, 53],
        [39, 50, 50,  1, 42, 43, 57, 54],
        [ 1, 54, 56, 39, 63,  1, 63, 53]])
tensor([[58,  1, 57, 43, 43, 51, 57,  1],
        [50,  1, 51, 43,  1, 58, 53,  1],
        [50, 50,  1, 42, 43, 57, 54, 39],
        [54, 56, 39, 63,  1, 63, 53, 59]])


In [105]:
for b in range(batch_size): # Batch dimension
    for t in range(block_size): # Time dimension
        context = xb[b][:t+1]
        target = yb[b][t]
        print( f"Input: { context } - Target: { target }" )

Input: tensor([6]) - Target: 0
Input: tensor([6, 0]) - Target: 21
Input: tensor([ 6,  0, 21]) - Target: 44
Input: tensor([ 6,  0, 21, 44]) - Target: 1
Input: tensor([ 6,  0, 21, 44,  1]) - Target: 61
Input: tensor([ 6,  0, 21, 44,  1, 61]) - Target: 43
Input: tensor([ 6,  0, 21, 44,  1, 61, 43]) - Target: 1
Input: tensor([ 6,  0, 21, 44,  1, 61, 43,  1]) - Target: 61
Input: tensor([58]) - Target: 52
Input: tensor([58, 52]) - Target: 43
Input: tensor([58, 52, 43]) - Target: 57
Input: tensor([58, 52, 43, 57]) - Target: 57
Input: tensor([58, 52, 43, 57, 57]) - Target: 2
Input: tensor([58, 52, 43, 57, 57,  2]) - Target: 1
Input: tensor([58, 52, 43, 57, 57,  2,  1]) - Target: 57
Input: tensor([58, 52, 43, 57, 57,  2,  1, 57]) - Target: 43
Input: tensor([1]) - Target: 59
Input: tensor([ 1, 59]) - Target: 52
Input: tensor([ 1, 59, 52]) - Target: 39
Input: tensor([ 1, 59, 52, 39]) - Target: 41
Input: tensor([ 1, 59, 52, 39, 41]) - Target: 46
Input: tensor([ 1, 59, 52, 39, 41, 46]) - Target: 47

# Bigram Language Model

In [106]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        # idx is the inputs
        # targets is the outputs
        # B -> Batch   -> Batch Size (4)
        # T -> Time    -> Block Size (8)
        # C -> Channel -> Vocab Size (65)
        logits = self.token_embedding_table(idx) # (B, T, C)

        return logits


m = BigramLanguageModel(vocab_size)
out = m(xb, yb)
print(out.shape)

torch.Size([4, 8, 65])


In [134]:
print( decode([15]) ) # The letter C

print( m.token_embedding_table.weight[15] ) # the weights for embedding 15 (C)

some_x = torch.tensor([ 15 ])
# Right now, y isn't doing anything in our forward function so we can change it with no impact
some_y = torch.tensor([ 22 ]) 
out = m( some_x, some_y ) 
print( out )

C
tensor([-0.2060,  1.5973,  0.1185, -1.2549,  1.3024,  0.4760, -0.8871,  1.3709,
        -1.9473, -0.8017, -1.3055, -0.4910,  0.4430,  0.2178, -0.3297, -0.0192,
         0.9225,  0.9187,  0.2998,  0.6106,  0.7791,  0.1237,  1.8620,  1.7080,
        -1.6045,  0.3338, -2.0513,  0.5923,  0.4880, -1.4055, -0.6686, -0.4831,
        -0.2298,  0.9043,  0.7631, -0.1606,  0.9156, -0.6908, -0.3065, -1.1809,
         0.8175, -2.0392,  0.1558, -0.2996, -0.5391, -0.3657,  0.8282, -0.4826,
         1.8330,  0.3421,  0.2154, -0.1029, -0.0946,  0.0070,  0.1484, -0.5403,
        -1.9312, -0.7858, -0.6731, -0.0901,  0.2598, -0.5349,  0.5812, -0.5356,
        -1.7944], grad_fn=<SelectBackward0>)
tensor([[-0.2060,  1.5973,  0.1185, -1.2549,  1.3024,  0.4760, -0.8871,  1.3709,
         -1.9473, -0.8017, -1.3055, -0.4910,  0.4430,  0.2178, -0.3297, -0.0192,
          0.9225,  0.9187,  0.2998,  0.6106,  0.7791,  0.1237,  1.8620,  1.7080,
         -1.6045,  0.3338, -2.0513,  0.5923,  0.4880, -1.4055, -0.6686

In [215]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx is (B, T) array of indices in the current context
        # targets is the outputs, also (B, T)
        # B -> Batch   -> Batch Size (4)
        # T -> Time    -> Block Size (8)
        # C -> Channel -> Vocab Size (65)
        logits = self.token_embedding_table(idx) # (B, T, C)

        if targets is None:
            loss = None
        else:
            # We have a multi-dimensional input (B, T, C) so this won't work
            #  the F.corss_entropy wants the C to be the second dimension, so (B, C, T)
            #loss = F.cross_entropy(logits, targets)
            # So, we have to reshape our logits
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # Stretch it out
            reshaped_targets = targets.view(B*T)
            loss = F.cross_entropy( logits, reshaped_targets )

            # Seems like you could also do this
            #reshaped_logits = logits.permute(0, 2, 1) 
            #loss = F.cross_entropy( reshaped_logits, targets )
            #print(loss)

        # Not the best code, returns different logits dimensions 
        #    depending on whether the targets were provided. If provided
        #    it returns the reshaped dimensions, otherwise, the original.
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context (i.e. the encoded string(s) we're trying to generate additional text for)
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            #   i.e. grab the logits representing the last character in the string represented by idx)
            #   i.e. plucks out the last element of the time dimension
            logits = logits[:, -1, :]
            # apply softmax to get probabilities
            probabilities = F.softmax(logits, dim=-1)

            # sample from the distribution
            idx_next = torch.multinomial(probabilities, num_samples=1)  # (B, 1)

            # append sampled index to the running sequence
            #   dim=1 means to use the T (time) dimension, since idx is (B, T)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
            

model = BigramLanguageModel(vocab_size)
# logits, loss = m(xb, yb)
# print(logits.shape)

# torch.zeros((1, 1), dtype=torch.long)
# torch.tensor([encode('your')])

out = model.generate( torch.tensor([encode('your')]), 20 )[0].tolist()
print( decode( out ) )
# m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


yourbM
3LUZurpqB&W-Z XiS

iy!FvS-:
jbcGpE,GIoyeS?WsU'KWKwuK:RDTsSxO.LuW;.AHoV3XiLo.KLmuFkdRbzSJaiSMK:-gyKQIE&WqFbljnhycvfvK:Uo


## Training the Bigram model

In [216]:
learning_rate = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [237]:
batch_size = 32
max_steps = 1000

for i in range(max_steps):
    xb,yb = get_batch('train')
    logits, loss = model(xb, yb)
    # Zero-out gradients from previous step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # This will update all of our parameters, magically!
    optimizer.step()
print(loss.item())

2.4422693252563477


In [241]:
out = model.generate( torch.tensor([encode('your')]), 20 )[0].tolist()
print( decode( out ) )
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


yourd ke s bith main JUS


S:
ONRIULO:
Nequ hickn araroug,


HAnor the frein se r s

LLe bethitcais icheldee sp-eende chth lec


# Math Trick for self-attention


We want tokens to only communicate with past tokens. One way is to take the averages of all the previous tokens (even though we lose the positional information).

In [313]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

### Inefficient approach

In [314]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # x bag-of-words (used when averaging)
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

In [315]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [316]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

The first row is the same, but we see subsequent rows are the average of all previous rows.

In [317]:
# This should be the same as what we see in the last row
x[0].mean(dim=0)

tensor([-0.0341,  0.1332])

### Efficient Approach

In [318]:
torch.ones(3,3)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [319]:
# returns lower triangular portion of matrix
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [320]:
temp = torch.tril(torch.ones(3,3))
print(temp)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [321]:
torch.sum(temp)

tensor(6.)

In [322]:
torch.sum(temp, dim=0)

tensor([3., 2., 1.])

In [323]:
torch.sum(temp, dim=1)

tensor([1., 2., 3.])

In [324]:
torch.sum(temp, dim=1, keepdim=True)

tensor([[1.],
        [2.],
        [3.]])

In [325]:
temp = temp / torch.sum(temp, dim=1, keepdim=True)
print(temp)

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [326]:
torch.manual_seed(1337)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])
--
c=
tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])


So...

In [327]:
wei = torch.tril( torch.ones(T,T) )

In [328]:
wei = wei / wei.sum(1, keepdim=True)

In [329]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [336]:
# x is from way at the top of this section

# (T, T) @ (B, T, C) -> 
# (B, T, T) @ (B, T, C) -> 
# (B, T, C)
xbow2 = wei @ x 

In [335]:
# compare inefficient method to efficient method
torch.allclose(xbow, xbow2)

True

### One More Method

In [344]:
torch.tril( torch.ones( T, T) )


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [345]:
wei = torch.zeros((T,T))

In [346]:
wei = wei.masked_fill( tril == 0, float('-inf'))
print( wei )

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [347]:
wei = F.softmax(wei, dim=-1)
print(wei)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [348]:
tril = torch.tril( torch.ones( T, T) )
wei = torch.zeros((T,T))
wei = wei.masked_fill( tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

# Self Attention

The current code just takes a simple average of current and past tokens using the lower triangular structure. But, we don't necessarily want this weighting to be uniform because some tokens will find other tokens more intersting and we may want them to have a higher affinity. A vowel may want constanants from the past, in a data-dependent way. 

This is what self-attention solves. 

Every single node/token at each position will emit two vectors: query and key.

Query = What am I looking for
Key = What do I contain

Affinities are the dot products between the queries and the keys. So, for a given node/token, you can take its query and dot product it with all the other keys for all the other nodes/tokens. That resulting dot product becomes the new "wei" (weights??).

If the key/query are aligned, they will interact in a high amount, and I'll learn more about that token than other tokens in the sequence.

We're now going to implement a single "head" of self attention.

In [363]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)
print("x shape: ", x.shape)

# let's see a single Head perform self-attention
# A linear layer will transform from the in_features to the out_features
head_size = 16
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False) 
value = nn.Linear(in_features=C, out_features=head_size, bias=False)

# (B, T, C) -> (B, T, head_size)
# In parallel, this produces a key/query for every token
#     so no communication has happened (yet)
k = key(x)
q = query(x)

# wei is now the affinities between the keys/queries
#  to do the multiplication, we have to transpose the last two dimensions
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) -> (B, T, T)

# for every batch we now have a (T, T) matrix representing the affinities

print( "k shape: ", k.shape)
print( "q shape: ", q.shape)

tril = torch.tril( torch.ones( T, T) )
# No longer using zeros
#wei = torch.zeros((T,T))

# This is used in a "decoder" where you're using an auto-regressive approach
#   so we don't talk to the future nodes. But if we wanted an "encoder", for
#   something like sentiment analysis for example, then we'd just need to 
#   delete this line so that nodes can all talk to each other.
wei = wei.masked_fill( tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

#out = wei @ x
v = value(x)
out = wei @ v


out.shape
wei[0]

x shape:  torch.Size([4, 8, 32])
k shape:  torch.Size([4, 8, 16])
q shape:  torch.Size([4, 8, 16])


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

This now tells us how much of the data to aggregate from tokens in the past.

x is thought of as private information for each token (the embedding value??). 
k is what each token has
q is what each token is looking for
v is what is communicated based on the affinities between different tokens

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [377]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei_naive = q @ k.transpose(-2, -1)

# During initialization we need wei to be diffuse
wei_scaled  = q @ k.transpose(-2, -1) * head_size**-0.5

In [373]:
k.var()

tensor(1.0966)

In [374]:
q.var()

tensor(0.9416)

In [375]:
wei_naive.var()

tensor(16.1036)

In [376]:
wei_scaled.var()

tensor(1.0065)

During initialization we want our values to be fairly diffuse, because if they're too big then they will give too much bias to the biggest term. For example:


In [378]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [379]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])