In [1]:
# getting the dataset for training 
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-05-15 15:07:17--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-05-15 15:07:18 (7.69 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [2]:
# read and inspeect the text file
with open('input.txt', 'r') as f:
    text = f.read()

In [3]:
# exploring the txt file

# looking at first 1000 characters
print(text[:1000], '\n')

#number of characters in the dataset
print(f"length of the dataset: {len(text)}")

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [4]:
# vocabulary and the size of vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"vocabulary: {''.join(chars)}")
print(f"size of vocabulary: {vocab_size}")

vocabulary: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
size of vocabulary: 65


In [5]:
# tokenizing the text; ie convert the raw strings to some sequences of integers

# creating a mapping of characters to integers
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda st: [stoi[x] for x in st] # function of the string which takes the string and returns a list of integers
decode = lambda l: ''.join([itos[x] for x in l]) # function of the lost of integers where it takes each elements of the list and returns a character and joins them together

print(encode('hii there'))
print(decode(encode('hii there')))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [6]:
# encoding the entire dataset and storing it as a torch tensor
import torch
data = torch.tensor(encode(text), dtype = torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [7]:
print(data[:100])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [8]:
# splitting the data into train and validation sets
n = int(0.9 * len(data)) # first 90% will be train, rest 10% will be validation set
train_data = data[:n]
val_data = data[n:]
# validation set will help us understand how much our model is overfitting and memorising 

In [9]:
# the entire text is never fed into the transformer in one go as it is computationally prohibitive
# chunking the dataset and sample these chunks into the transformer, thus training transformer on chunks at a time
block_size = 8 # chunk size
"""seeing first 9 chars
we are chunking for 8 chars. Since each preceding sequence of characters must predict for the next character. so, 18 would predict 47, 18 and 47 would predict 56.
So this way we have actually 8 individual training examples here
"""
train_data[:block_size + 1] 


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
x = train_data[: block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1] #context is all preceding chars plus current character
    target = y[t] 
    print(f" Example {t+1}: When input: {context}, target: {target}")

# this masked type attention is done to make transformer used to seeing all kinds of lengths of context

 Example 1: When input: tensor([18]), target: 47
 Example 2: When input: tensor([18, 47]), target: 56
 Example 3: When input: tensor([18, 47, 56]), target: 57
 Example 4: When input: tensor([18, 47, 56, 57]), target: 58
 Example 5: When input: tensor([18, 47, 56, 57, 58]), target: 1
 Example 6: When input: tensor([18, 47, 56, 57, 58,  1]), target: 15
 Example 7: When input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
 Example 8: When input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [11]:
# we stack the batches of chunks into stacks and then feed them into the transformer. 
# we will have many batches of many chunks of text that are all stacked up in a single tensor.
# we do this because GPUs are very good at paralle processing.
# These chunks are trained upon independently

# we will start sampling random locations in the dataset to pull chunks from 

torch.manual_seed(1337) # we are setting the seed for the random number generator so that the random locations pulled are same whenever we train the model
batch_size = 4 # how many independent sequences will we process in parallel
block_size = 8 # what is the maximum context length for predictions

def get_batch(split):
    # generate a small batch of data of inputs x and target y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data)-block_size, (batch_size,)) # we generate batch_size number of random offsets, ie 4 values. 
    # these 4 values should be between 0 and len(data)-block_size
    x = torch.stack([data[i: i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    # torch.stack to take all 4 of the 1-D tensors and stack them up as rows in a 4x8 tensor
    return x, y

xb, yb = get_batch('train')
print(f'inputs: {xb.shape} shape, \n{xb}')
print(f'\ntargets: {yb.shape} shape, \n{yb}')
print('----------------')
for b in range(batch_size):
    for t in range (block_size):
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f'when input is {context.tolist()}, the target is {target}')

inputs: torch.Size([4, 8]) shape, 
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

targets: torch.Size([4, 8]) shape, 
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----------------
when input is [24], the target is 43
when input is [24, 43], the target is 58
when input is [24, 43, 58], the target is 5
when input is [24, 43, 58, 5], the target is 57
when input is [24, 43, 58, 5, 57], the target is 1
when input is [24, 43, 58, 5, 57, 1], the target is 46
when input is [24, 43, 58, 5, 57, 1, 46], the target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43], the target is 39
when input is [44], the target is 53
when input is [44, 53], the target is 56
when input is [44, 53, 56], the target is 1
when input is [44, 53, 56, 1], the target is 58
when 

In [12]:
# feeding the input into very simple bigram neural network

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

"""
bigram model:
If you input 'h', it tries to predict 'e', 'i', etc., based only on 'h'.
It doesn't look at what came before 'h' (no 't', 's', etc. before it — just 'h').

# -------------------------------
# Bigram Model Explanation
# -------------------------------
# We're using an nn.Embedding(vocab_size, vocab_size) layer here, which may look like a typical embedding layer,
# but it's not used in the traditional sense.
#
# In most NLP models, embeddings represent tokens (e.g., characters or words) as low-dimensional dense vectors
# that capture semantic meaning and are used as input to deeper layers.
#
# However, in this bigram model, we use the embedding table to directly map each input token (an integer index)
# to a vector of size vocab_size that represents the raw logits (unnormalized scores) for the next character.
# So:
#     Input token index → Lookup corresponding row in embedding table → Row used as logits for next character
#
# We use nn.Embedding here simply because it's a convenient way to index a learnable weight matrix using integers.
# The output of this embedding is not a dense semantic vector — it's the actual logits for the next prediction step.

# Dimensions used:
# B = Batch size       → Number of sequences processed in parallel
# T = Time steps       → Number of tokens in each sequence
# C = Channels         → Size of output vector per token, here equal to vocab_size (i.e., number of possible next tokens)

# At this stage, the model has not been trained — so the logits it outputs are essentially random.
# Training (via loss and backpropagation) will adjust the embedding table so that each input token learns to
# predict the most likely next token based on character-level bigram statistics.

"""
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets = None):
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C) Batch: 4, time: 8, channels is vocab size: 65
        # when we pass idx here, every integer in our input is going to refer to this embedding table 
        # and is going to pluck out a row of that embedding table corresponding to its index
        if targets is None:
            loss=None
        else:            
            # pytorch is then going to arrange all of this into a batch X time X channel tensor
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # we are doing this because cross entropy loss function in pytorch expects a B,C,T tensor
            targets = targets.view(B*T) # right now the targets are of shape B,T, we are making it 1D.
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss # logits are the scores for the next character in the seq
        # we are predicting the next char just based on just individual identity of a single token.
        # currenly the tokens are not seeing any context or interacting with each other 
        # this can still do some predictions

    def generate(self, idx, max_new_tokens): # gives us a running stream of generations.
        # idx is the (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get predictions
            logits, loss = self(idx)
            # focus on only the last time step 
            logits = logits[:, -1, :] # becomes (B, C) for the T+1 char, basically we are plucking out thr last element in the time dimension.
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples = 1) # (B,1) in each of the batch dimensions, we will have a single prediction for what comes next 
            # append sampled index to the running sequence 
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
            # whatver is predicted is concatenated on top of previous idx along the time dimension. so basically this whole thing takes (B,T)
            # and makes it to (B,T+1), (B, T+2) and so on upto max_new_tokens.
        return idx

m = BigramLanguageModel(vocab_size)
logits,loss = m(xb, yb)
print(logits.shape)
print(loss)

idx_trial = torch.zeros((1,1), dtype=torch.long )# creating a tensor where batch is 1 and time is 1, and its value is 0. the datatype is integer
print(decode(m.generate(idx_trial, max_new_tokens = 100)[0].tolist())) # because generate works on the level of batches, we have to get the 0th row
# to get the single batch dimension that is getting output, 

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [13]:
# we can guess what the loss should be
# we are expecting -ln(1/65) ~ 4.17
"""You're using:
vocab_size = 65 → 65 possible characters
A randomly initialized model, so:
The logits (output of the embedding table) are random, and therefore
After softmax, the probability distribution over the next character is nearly uniform (i.e., ~1/65 for each character).
Cross-Entropy Loss:
Cross-entropy compares the predicted probability distribution (p̂) with the true distribution (p = one-hot).

If the model predicts all 65 characters equally likely (uniform distribution), and the true next character is 'e', then:

loss = − log(predicted prob of correct token) = − log(1/65) ≈ 4.17
So yes -ln(1/65) is the expected loss for a uniform random model.

so, what does loss tell us?
f you get a loss close to 4.17 before training:
It’s a good sanity check — your model is producing an almost uniform distribution, as expected.
If the loss is much lower, something might be wrong (like model memorizing already, or bad labels).
If the loss is higher, your softmax might be too “spiky” (concentrated on a few wrong tokens), which can happen if your logits are initialized poorly.
"""



"You're using:\nvocab_size = 65 → 65 possible characters\nA randomly initialized model, so:\nThe logits (output of the embedding table) are random, and therefore\nAfter softmax, the probability distribution over the next character is nearly uniform (i.e., ~1/65 for each character).\nCross-Entropy Loss:\nCross-entropy compares the predicted probability distribution (p̂) with the true distribution (p = one-hot).\n\nIf the model predicts all 65 characters equally likely (uniform distribution), and the true next character is 'e', then:\n\nloss = − log(predicted\xa0prob\xa0of\xa0correct\xa0token) = − log(1/65) ≈ 4.17\nSo yes -ln(1/65) is the expected loss for a uniform random model.\n\nso, what does loss tell us?\nf you get a loss close to 4.17 before training:\nIt’s a good sanity check — your model is producing an almost uniform distribution, as expected.\nIf the loss is much lower, something might be wrong (like model memorizing already, or bad labels).\nIf the loss is higher, your softma

In [14]:
# training the bigram model
# creating a pytorch optimizer object 
optimizer = torch.optim.AdamW(m.parameters(), lr = 1e-3)

In [16]:
batch_size = 32
# training for 100 epochs
for steps in range(10000):
    # sample a batch of data
    xb, yb = get_batch('train')
    #evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True) # zeroing out all gradients from previous step
    loss.backward() # getting the gradients for all of the params
    optimizer.step() # using the gradients to updte the params
print(loss.item())

2.4522743225097656


In [17]:
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long ), max_new_tokens = 300)[0].tolist()))


DUKIVisun Casshe wisthiot s.
LUK:

NGOLI io e alllker s?j$NCowens l het hislaspicobar, heay ind, cigigeluandac! thaforo nont
SLO:
Ange ive nn I ou m,
UCENTheanp'Lbet bazzl
TEEXNore t b'Thathon:
sous min'd ne st wousis s lingilo whee,
K:
Toow'e's,
D:
NGLEng, do te! ase may sin ceecate.
God? d
Aw ht h


## mathematical trick for self-attention

In [18]:
# toy example
torch.manual_seed(1337)
B,T,C = 4, 8, 2
x = torch.rand(B,T,C)


In [19]:
# we have upto 8 tokens in batch which are currently not interacting with each other
# to couple them in such a way that a token only interacts with its preceding tokens, because we want to predict the future tokens
# the easiest way for tokens to interact is taking average of all the preceding tokens and current token.
# it becomes the feature vector that summarises the current vector in contxt of its history
# this commuication is extremely lossy as it has lost all positional information about tokens.

# we want x[b, t] = mean_{i<t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(b):
    for t in range(T):
        xprev = x[b, :t+1] # (t, C) # t: how many preceding tokens
        xbow[b,t] = torch.mean(xprev, 0)  # averaging out over the 0th dimension ie averaging out T

In [20]:
x[0]

tensor([[0.0783, 0.4956],
        [0.6231, 0.4224],
        [0.2004, 0.0287],
        [0.5851, 0.6967],
        [0.1761, 0.2595],
        [0.7086, 0.5809],
        [0.0574, 0.7669],
        [0.8778, 0.2434]])

In [21]:
xbow[0]

tensor([[0.0783, 0.4956],
        [0.3507, 0.4590],
        [0.3006, 0.3156],
        [0.3717, 0.4108],
        [0.3326, 0.3806],
        [0.3953, 0.4140],
        [0.3470, 0.4644],
        [0.4134, 0.4368]])

In [22]:
# we can see that in x[0] and xbow[0], ,the first row is the same since its the same token
#(bow is bag of words which usually represents a simple average of token representations)

In [32]:
# this way of calculating average is very inefficient, instead we can use matrices
# matrix multiplication
torch.manual_seed(42)
a = torch.ones(3,3)
b= torch.randint(0,10, (3,2)).float() # a matrix of size 3,2 and it should have random values between 0 and 10
c = a @ b
print(f"a = {a}\n--------\nb = {b}\n----------\nc = {c}\n")

# using matrix multiplication to get average over tokens
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
print(f'Triangular matrix a: {a}\n')

a = a / torch.sum(a, 1, keepdim=True)
b = b= torch.randint(0,10, (3,2)).float()
c = a @ b
print(f"a = {a}\n--------\nb = {b}\n----------\nc = {c}\n")
# we can see in c that first row is average of just first row of b, second row is avrage of first and second row of b and so on

a = tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
--------
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----------
c = tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])

Triangular matrix a: tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

a = tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--------
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----------
c = tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])



In [33]:
wei = torch.tril(torch.ones(T,T))
wei = wei/wei.sum(1, keepdim = True)
xbow2 = wei@x # (B, T, T) @ (B, T, C)---> (B, T, C)
torch.allclose(xbow, xbow2)

False

In [37]:
xbow[0], xbow2[0] # they are the same 
# we are leveraging batch matrix multiplication to basically get a weighted average of all tokens in a batch
#all the tokens post current tooken are weighted down to 0 using a triangular matrix

(tensor([[0.0783, 0.4956],
         [0.3507, 0.4590],
         [0.3006, 0.3156],
         [0.3717, 0.4108],
         [0.3326, 0.3806],
         [0.3953, 0.4140],
         [0.3470, 0.4644],
         [0.4134, 0.4368]]),
 tensor([[0.0783, 0.4956],
         [0.3507, 0.4590],
         [0.3006, 0.3156],
         [0.3717, 0.4108],
         [0.3326, 0.3806],
         [0.3953, 0.4140],
         [0.3470, 0.4644],
         [0.4134, 0.4368]]))

In [41]:
# another way to get a masked matrix of weights to get average of preceding tokens
# method using softmax
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril==0, float('-inf')) # wherever tril is 0, we make that -inf in the wei
print(f'wei matrix after masked fill: {wei}\n')
wei = F.softmax(wei, dim=-1)
print(f'wei after applying softmax: {wei}\n') 
xbow3 = wei@x
torch.allclose(xbow, xbow3)

wei matrix after masked fill: tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

wei after applying softmax: tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.125

False

In [47]:
# masked self attention without averaging but using Q,K,V
# ie getting wei which is dependent on data
"""
how this works:
every single node/token will emit two vectors: query and key
query: what am i looking for
key: what do i contain
the way we get affinities between tokens now is we do a dot product between keys and tokens
"""
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# lets see a single head perform self attention 
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias = False)
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
v = value(x)
# when i forward these linears on top of x, all of the positions in the b,t arrangement in parallel and independently produce a key and a query
# the communication happens now
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

# now we do masking
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [48]:
wei[0] # wei now tells us how much information to aggregate from all of the tokens 

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [49]:
"""
notes about self attention:
- there is no notion of position, attention simply acts over a set of vectors. thats why we need to positionally encode the input tokens.
- each chunk over batch dimension is processed completely independently and never talk to each other.
- it basically means matrix multiplication is aplied in parallel across the batch dim
- if our batch size is 4, that means that we have separate pools of 8 nodes and those nodes only talk to each other.

difference between self and cross attention
self-attention: keys, queries and values are all coming from the same source, from x
cross-attention: queries are coming from x but the keys and the values come from separate source.
"""

'\nnotes about self attention:\n- there is no notion of position, attention simply acts over a set of vectors. thats why we need to positionally encode the input tokens.\n- each chunk over batch dimension is processed completely independently and never talk to each other.\n- it basically means matrix multiplication is aplied in parallel across the batch dim\n- if our batch size is 4, that means that we have separate pools of 8 nodes and those nodes only talk to each other.\n'

In [62]:
# scaled attention normalisation method
# we have gaussian distributed values with unit variance and zero mean for k and q
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1)

print(f"variance without scaling: {k.var(), q.var(), wei.var()}")

#variance for wei (if calculated naively) is of the order of head_size

# but if we multiply by head_size**-0.5
wei = q @ k.transpose(-2, -1) * head_size**-0.5
print(f"variance with scaling: {k.var(), q.var(), wei.var()}")

# we need to do this because wei gets fed into the softmax function, so it needs to be fairly diffused. 
# otherwise, if wei takes on very posisitive or very negative numbers inside it, softmax will actially converge towards one hot vectors
# this would mean eery node would end up aggregating information from only one single other node
# example
print(f"softmax when tensor values are diffused: {torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)}")
print(f"softmax when tensor values are spiky and go very low or very high: {torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*20, dim=-1)}")

variance without scaling: (tensor(1.0172), tensor(0.9599), tensor(12.9863))
variance with scaling: (tensor(1.0172), tensor(0.9599), tensor(0.8116))
softmax when tensor values are diffused: tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
softmax when tensor values are spiky and go very low or very high: tensor([3.2932e-04, 8.1630e-07, 1.7980e-02, 8.1630e-07, 9.8169e-01])
