# Transformers - Attention is all you need

In [14]:
!pip install sentencepiece

Collecting sentencepiece
  Obtaining dependency information for sentencepiece from https://files.pythonhosted.org/packages/de/42/ae30952c4a0bd773e90c9bf2579f5533037c886dfc8ec68133d5694f4dd2/sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl.metadata
  Downloading sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (7.7 kB)
Downloading sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [15]:
import sentencepiece as spm
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-03-26 16:26:03--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-03-26 16:26:03 (9.13 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open('input.txt', 'r', encoding= 'utf-8') as f:
    text = f.read()

In [7]:
print(f"Len of dataset characters : {len(text)}")

Len of dataset characters : 1115394


In [10]:
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


## Make a vocabulary list of the characters used : (will not use this)

In [11]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


### Make Encoder/Decoder (will not use this)


In [13]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("Hello World!"))
print(decode(encode("Hello World!")))

[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Hello World!


In [37]:
vocab_size = 5000

In [None]:
vocab_size = 5000
spm.SentencePieceTrainer.Train('--input=input.txt --model_prefix=spm --vocab_size=vocab_size')

sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=input.txt --model_prefix=spm --vocab_size=5000
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: input.txt
  input_format: 
  model_prefix: spm
  model_type: UNIGRAM
  vocab_size: 5000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_pie

In [23]:
s = spm.SentencePieceProcessor(model_file='spm.model')
for n in range(5):
    a = s.encode(text[:100], out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
print(a)

['▁First', '▁', 'C', 'it', 'i', 'zen', ':', '▁Be', 'f', 'or', 'e', '▁we', '▁pr', 'o', 'ce', 'e', 'd', '▁an', 'y', '▁f', 'u', 'r', 'ther', ',', '▁', 'he', 'a', 'r', '▁me', '▁sp', 'eak', '.', '▁All', ':', '▁Speak', ',', '▁sp', 'eak', '.', '▁First', '▁Citizen', ':', '▁You']


### We will use the spm SentencePieceProcessor to encode/decode

### Encode all the text

In [28]:
data = torch.tensor(s.encode(text, out_type=int, enable_sampling=True, alpha=0.1, nbest_size=-1), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([535228]) torch.int64
tensor([ 171,  370,    4, 1121,   14,  319,   87,  615,  624,  699,   45,  224,
        1068,    3,   14,  363, 1416,   14, 1203,  151,    5,   65,  134,  134,
           4, 1113,    3,  151,    5,  131,  547,   70,  370,    4,  118,   15,
         181,   51,   14,  170,   87,    6, 1017,  637,   17,   14,  170, 2099,
          10, 1193,   87,  123,   10,  676,  717, 1034,   19,   14,  159,   61,
           4,  163,   87,    6, 1017,  402,   87,   17,    5,  992,  231,  272,
         134,  637,   17,    5,  131,  547,    6,   47,   14,  488,  523, 1976,
          75,    4,   14,  145,  547,   70,    3,   16,  108, 1809,  527,   14,
         892,   14,  396, 1123,  353,  867, 1030,  272,    8, 1788, 1164,  271,
           5,   14,  159,  134,  134,    4,  620,   87,  108,    9,   47,    3,
         923,   87,   14,  417,   75,  272,  319,    9,   47,    5,  171,   14,
         488,  401,   47,  401, 1781,    4,  269,  102,  502,   37,    3,   14,
       

### Split into train and val

In [29]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

Size of the block of text used to train

In [30]:
block_size = 8
train_data[:block_size+1]

tensor([ 171,  370,    4, 1121,   14,  319,   87,  615,  624])

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is : {target}")

When input is tensor([171]), the target is 370
When input is tensor([171, 370]), the target is 4
When input is tensor([171, 370,   4]), the target is 1121
When input is tensor([ 171,  370,    4, 1121]), the target is 14
When input is tensor([ 171,  370,    4, 1121,   14]), the target is 319
When input is tensor([ 171,  370,    4, 1121,   14,  319]), the target is 87
When input is tensor([ 171,  370,    4, 1121,   14,  319,   87]), the target is 615
When input is tensor([ 171,  370,    4, 1121,   14,  319,   87,  615]), the target is 624


In [35]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    # generate a smal batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs')
print(xb.shape)
print(xb)
print('targets')
print(yb.shape)
print(yb)

print('--------')

for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context.tolist()}, the target is : {target}")

inputs
torch.Size([4, 8])
tensor([[2264,  170, 3766,    5,   14,   74,  163,  368],
        [ 402,   54,   80,  109,    3,  104,  275,  399],
        [ 923, 1525,    3,  923,  401,   47,  363,   62],
        [ 891, 1031,   27,   70,  477,   17,   87,    8]])
targets
torch.Size([4, 8])
tensor([[ 170, 3766,    5,   14,   74,  163,  368,  965],
        [  54,   80,  109,    3,  104,  275,  399,  142],
        [1525,    3,  923,  401,   47,  363,   62,  851],
        [1031,   27,   70,  477,   17,   87,    8,   14]])
--------
When input is [2264], the target is : 170
When input is [2264, 170], the target is : 3766
When input is [2264, 170, 3766], the target is : 5
When input is [2264, 170, 3766, 5], the target is : 14
When input is [2264, 170, 3766, 5, 14], the target is : 74
When input is [2264, 170, 3766, 5, 14, 74], the target is : 163
When input is [2264, 170, 3766, 5, 14, 74, 163], the target is : 368
When input is [2264, 170, 3766, 5, 14, 74, 163, 368], the target is : 965
When input

## Lets create a simple bigram language model :

In [47]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self,vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and target are both (B,T), tensors of integers
        logits = self.token_embedding_table(idx) #(B,T,C) = batch (4) * time (8) * channel (vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) #Pytorch expect (B,C,T)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)  
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb,yb)
print(logits.shape)  
print(loss)

idx = torch.zeros((1,1), dtype=torch.long)
print(s.decode(m.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 5000])
tensor(9.4888, grad_fn=<NllLossBackward0>)
 ⁇  Cheer chafe GrumioVOL fond weep trodage inform empty affect disguise aught Volscian Me executioner Pale volume bestow suddenlyread wherein date therefore blot Bianca shalt character beginningi pilgrimage writvoid canst spoken rein custom dreams Minola poHA expect Lodowick calf thy prescri fe Fortune joy destin jade accused Rome Condemnplotted depend wear damned slaughterthat state shelter mistrust corn whisper strait Offic grave defence Music puritIN suffice loss dastard unseen lords submissi already lust school supposed entertain verity skull Master rich reckon infringe Greekg did forsake home woful plainly hurt morrow calm choler


## Lets train the model

In [49]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [57]:
batch_size = 32
for steps in range(1000):
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    print(loss.item())

7.22401237487793
7.447227478027344
7.111723899841309
7.08571720123291
7.330059051513672
7.348686695098877
7.220452308654785
7.219447612762451
7.108402729034424
7.224458694458008
7.388709545135498
7.3378777503967285
7.273087501525879
7.1875128746032715
7.053735733032227
7.448727130889893
7.173687934875488
7.078999996185303
7.227609157562256
7.218874931335449
7.308000087738037
7.1261210441589355
7.218200206756592
7.220412731170654
7.256505489349365
7.099203586578369
7.247226715087891
7.046177387237549
7.226519584655762
7.234792232513428
7.182720184326172
7.253887176513672
7.135412216186523
7.270850658416748
7.27522611618042
7.2578630447387695
7.286032676696777
7.0365376472473145
7.152425765991211
7.200508117675781
7.277250289916992
7.279853343963623
7.245128154754639
7.217872142791748
7.260928153991699
7.193966865539551
7.218711853027344
7.349695682525635
7.132151126861572
7.2552385330200195
7.1605048179626465
7.26513147354126
7.273092269897461
7.220922470092773
7.247310161590576
7.19202

In [None]:
print(s.decode(m.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=400)[0].tolist()))

## Mathematical trick in self-attention

In [58]:
# cosider the following toy example:

torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [None]:
# version 1
# we want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros(B,T,C) # bow = bag of words (averaging)
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

### This can be really efficient using matrix multiplications

In [None]:
# version 2

wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B,T,T) @ (B,T,C) -----> (B,T,C)

torch.allclose(xbow,xbow2) # identicals

True

In [73]:
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [74]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

In [None]:
# version 3 : use Softmax
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf')) # tokens from the past cannot communicate with future
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow,xbow3)

True

In [77]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('---')
print('c=')
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [76]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('b=')
print(b)
print('---')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [None]:
# version 4 : self-attention
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channel
x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,16)
q = query(x) # (B,T,16)
wei = q @ k.transpose(-2, -1) # (B,T,16) @ (B, 16, T)  ------> (B, T, T)  

tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])

*** Notes: ***
- Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionnally encodes tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modelling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. 

In [99]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1) * head_size**-0.5 # scaling to controle the variance

In [100]:
k.var()

tensor(1.0632)

In [101]:
q.var()

tensor(0.9891)

In [102]:
wei.var()

tensor(0.9755)

In [103]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [104]:
wei[[0]]

tensor([[[-0.1023,  1.3987,  1.0764, -0.1497, -0.2601,  1.0896, -0.7415,
           0.9481],
         [-0.5935,  0.1866,  0.8442, -2.6846, -0.5995, -0.1049, -0.2314,
          -0.5133],
         [-1.1968, -0.6008, -0.7900,  1.1396,  2.5007,  0.5416,  0.2368,
          -0.2689],
         [-0.2726, -0.0992,  0.4969,  0.5409,  0.5688,  0.5101,  0.9770,
           1.1494],
         [ 0.9239,  0.0869,  0.7659, -1.2412, -1.0235, -0.1439,  1.1736,
          -0.8299],
         [-0.2004,  0.1792,  1.2278, -1.0577, -0.0860,  0.4595, -0.2013,
           0.8051],
         [ 0.4227,  1.0302,  0.4409, -2.7691,  0.6206,  0.8169, -1.8316,
          -1.0899],
         [-0.0164,  1.5899,  2.0445, -2.2490, -0.3421,  1.0581, -1.4024,
           0.3886]]])

In [88]:
out[0][0]

tensor([ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643,
         0.3612,  1.1679, -1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433,
         1.3488, -0.1396,  0.2858,  0.9651, -2.0371,  0.4931,  1.4870,  0.5910,
         0.1260, -1.5627, -1.1601, -0.3348,  0.4478, -0.8016,  1.5236,  2.5086],
       grad_fn=<SelectBackward0>)