In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x10cf8a2d0>

In [5]:
from types import SimpleNamespace

cfg = SimpleNamespace()
cfg.device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tokenizer

In [46]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O input_shakespear.txt

In [7]:
import torch
from datasets import load_dataset

class TextProcessor:
    def __init__(self, text=None):
        self.text = text
        if text:
            self.chars = sorted(list(set(text)))
            self.vocab_size = len(self.chars)
            self.stoi = {ch: i for i, ch in enumerate(self.chars)}
            self.itos = {i: ch for i, ch in enumerate(self.chars)}
            self.tokenized = torch.tensor(self.encode(text), dtype=torch.long)
    
    def encode(self, s):
            if isinstance(s, str):
                return torch.tensor([self.stoi[c] for c in s], dtype=torch.long)
            elif isinstance(s, list):
                return torch.tensor([self.stoi[c] for c in s], dtype=torch.long)
            else:
                raise TypeError("Input should be a string or a list of characters")
            
    def decode(self, l):
        if isinstance(l, torch.Tensor):
            l = l.tolist()
        if isinstance(l, list):
            return ''.join([self.itos[i] for i in l])
        else:
            raise TypeError("Input should be a tensor or a list of integers")
    
    @classmethod
    def from_imdb(cls):
        # Load the IMDb dataset
        # dataset = load_dataset("imdb")
        # texts = dataset['train']['text']
        # combined_text = ' '.join(texts)

        with open('input_shakespear.txt', 'r', encoding='utf-8') as f:
            combined_text = f.read()

        return cls(combined_text)

# Example usage:
text = "Hello World!"
tokz = TextProcessor(text)
print("Vocabulary Size:", tokz.vocab_size)
print("Encoded Text:", tokz.encode(text))
print("Decoded Text:", tokz.decode(tokz.encode(text)))

# Using the IMDb dataset
tokz = TextProcessor.from_imdb()
print("IMDb Vocabulary Size:", tokz.vocab_size)
print("Sample Encoded IMDb Text:", tokz.encode(tokz.text[:50]))
print("Sample Decoded IMDb Text:", tokz.decode(tokz.encode(tokz.text[:50])))

Vocabulary Size: 9
Encoded Text: tensor([2, 5, 6, 6, 7, 0, 3, 7, 8, 6, 4, 1])
Decoded Text: Hello World!
IMDb Vocabulary Size: 60
Sample Encoded IMDb Text: tensor([ 6,  6, 11,  9, 11, 13,  6,  9, 16,  6,  9, 10,  1,  9,  9, 19, 11, 16,
        19, 10, 14,  6,  6,  1,  1, 40, 51, 51, 47, 50, 19,  8,  8, 49, 34, 54,
         7, 39, 41, 51, 40, 52, 35, 52, 50, 38, 49, 36, 46, 45])
Sample Decoded IMDb Text: --2024-07-01 00:27:15--  https://raw.githubusercon


  self.tokenized = torch.tensor(self.encode(text), dtype=torch.long)


In [8]:
# Train and test splits
tokz = TextProcessor.from_imdb()
n = int(0.9*len(tokz.tokenized)) # first 90% will be train, rest val
train_data = tokz.tokenized[:n]
val_data = tokz.tokenized[n:]

train_data, val_data

  self.tokenized = torch.tensor(self.encode(text), dtype=torch.long)


(tensor([ 6,  6, 11,  ..., 10,  9,  9]),
 tensor([ 9, 24,  1,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,  7,  7,  7,  7,
          7,  7,  7,  7,  7,  7,  1,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,
          7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,  7,  7,  7,  7,  7,  7,  7,
          7,  7,  7,  1, 18, 15,  2,  1, 15,  7, 10, 13, 26,  1,  9, 50,  0,  1,
          1, 10,  9, 14,  9, 24,  1,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,
          7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,  7,  7,  7,  7,  7,  7,  7,
          7,  7,  7,  1,  7,  7,  7,  7,  7,  7,  7,  7,  7,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1, 10,  9,  9,  2,  1, 14,  7, 16, 18, 26, 20,
          9,  7, 11, 50,  0,  0, 11,  9, 11, 13,  6,  9, 16,  6,  9, 10,  1,  9,
          9, 19, 11, 16, 19, 10, 14,  1,  3, 14,  7, 15, 15,  1, 26, 21,  8, 50,
          4,  1,  6,  1, 58, 41, 45, 47, 52, 51,  7, 51, 55, 51, 59,  1, 50, 34,
         53, 38, 37,  1, 32, 10, 10, 10, 14, 12, 18, 13,  8, 10, 10,

In [9]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - cfg.block_size, (cfg.bs,))
    x = torch.stack([data[i:i+cfg.block_size] for i in ix])
    y = torch.stack([data[i+1:i+cfg.block_size+1] for i in ix])
    x, y = x.to(cfg.device), y.to(cfg.device)
    return x, y

cfg.bs = 32
cfg.block_size = 8

len(get_batch('train')), get_batch('train')[0].shape, get_batch('train')[1].shape

(2, torch.Size([32, 8]), torch.Size([32, 8]))

In [10]:
# X and Y
tokz.decode(get_batch('train')[0][0]), tokz.decode(get_batch('train')[1][0])

('arpathy/', '.... ...')

# Model

### Forward pass

In [11]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_sz):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_sz, vocab_sz)

    def forward(self, idx, targets=None):
        print(idx.shape)
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)
        print(logits.shape)

        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss


xb, yb = get_batch('train')
model = BigramLanguageModel(tokz.vocab_size)
# m = model.to(cfg.device)
logits, loss = model(xb, yb)
# print(logits.shape)
# print(loss)

assert logits.shape == (cfg.bs * cfg.block_size, tokz.vocab_size), logits.shape


torch.Size([32, 8])
torch.Size([32, 8, 60])


### Generating Text

In [12]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # this is a table of size (vocab_size, vocab_size)
        # each row corresponds to a token
        #           token1,            token2,           token3,           ...
        # token1 -> [p(token1|token1), p(token2|token1), p(token3|token1), ...]
        # token2 -> [p(token1|token2), p(token2|token2), p(token3|token2), ...]
        # token3 -> [p(token1|token3), p(token2|token3), p(token3|token3), ...]
        # ...
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)


    def forward(self, token_id, targets=None):
        # /!\ these comments explain the shape during training, for inference it's a bit different

        # token_numbers shape (batch_size, block_size) where block_size value is the token number from the vocab
        #                     (32, 8)

        logits = self.token_embedding_table(token_id)
        # logits shape is (batch_size, block_size, vocab_size)
        #                 (32, 8, 65)
        # it means for each a batch
        # we have a sequence of 8 positions
        # and for each position, what is the probability of each token in the vocab

        # karpathy call the logits BTC        
        # (B, T, C) means = (batch, time, channels)
        # B = batch size, T = position, C = channels
        assert logits.shape[2] == tokz.vocab_size

        if targets is None:
            loss = None
        else:            
            # batch, time, channels
            # why do we call them time???? and channells???
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


xb, yb = get_batch('train')
model = BigramLanguageModel(tokz.vocab_size)

# Let's look at the first pred
context = torch.zeros((1, 1), dtype=torch.long, device=cfg.device)

# We need to pass a context/sentence to autocomplete
print(tokz.decode(model.generate(context, max_new_tokens=500)[0].tolist()))

# as expected we got garbage!



Mn..-865]K0]hg[heqKi1uk-)ep6Kb8u%K16qiPiwd,PSyuidl64
iep[’=:,52q[(b3e-]yutRs3pK
oHv)]xOd‘e=Toe34Ta)P[Mw:COq0]Bg‘b(‘|‘q1Pdw6O7b1mSh)gg39g12nRm%t’/16o8p(q%]
i’nnOmr9aT-kM
‘(a6LvS5‘Lba=7.)y[c=i9T55LLcybt83lse61.BrL%KHiMwpdt%q4sbRy(q(tlL’C’98i1H=’B9tp:ei2B2o554t|agx1atu96e4B5[MbrlhT0Tc(08BrtO-vSi. gT8PK4O187b7Kb)r-%5’6u96Tgep/d Mlqo]2]
-8
SB2.][ L‘ktPeHKi]ea]risC/xtt5ce0(8ml‘5Klw8hmB5)3RS3psvh3‘,d4(m:Rt|tKLMdvcBxbiO‘’sx
uqo6lHO%64//L(lqed%gna‘:Cqh0S5=B:4[sic)]31o7‘y=’7M7xlLt%vC,R.lmmC5.T(==o%bl
8v,v


# Training

In [13]:
cfg.lr = 1e-2
cfg.epochs = 3000
cfg.eval_interval = 300
cfg.eval_iters = 200 

class Trainer:
    def __init__(self, model, cfg, get_batch, tokz):
        self.model = model
        self.cfg = cfg
        self.get_batch = get_batch
        self.tokz = tokz
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr)
        self.model = self.model.to(cfg.device)

    @torch.no_grad()
    def estimate_loss(self):
        out = {}
        self.model.eval()
        for split in ['train', 'val']:
            losses = torch.zeros(self.cfg.eval_iters)
            for k in range(self.cfg.eval_iters):
                X, Y = self.get_batch(split)
                logits, loss = self.model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
        self.model.train()
        return out

    def train(self):
        for iter in range(self.cfg.epochs):
            if iter % self.cfg.eval_interval == 0:
                losses = self.estimate_loss()
                print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

            xb, yb = self.get_batch('train')
            logits, loss = self.model(xb, yb)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()

        context = torch.zeros((1, 1), dtype=torch.long, device=self.cfg.device)
        print(self.tokz.decode(self.model.generate(context, max_new_tokens=500)[0].tolist()))


trainer = Trainer(model, cfg, get_batch, tokz)
trainer.train()

step 0: train loss 4.1092, val loss 4.1047
step 300: train loss 1.4260, val loss 2.1074
step 600: train loss 1.1929, val loss 2.0399
step 900: train loss 1.1302, val loss 2.0746
step 1200: train loss 1.1093, val loss 2.1169
step 1500: train loss 1.1159, val loss 2.1243
step 1800: train loss 1.1004, val loss 2.1240
step 2100: train loss 1.1166, val loss 2.1563
step 2400: train loss 1.0976, val loss 2.2202
step 2700: train loss 1.0905, val loss 2.1834

 .0spspercom/rpaw. .................26M 43::500s
   (1M  ...................4% 67.360cont  . .  .... 60K .......... . 50comatentecoto::8M 243::/m/coma/m husengintingt........... 0: (rchubus
 ..0s
 ..... 9% raw.... 0K   ............50s
  8....  ............ 4 6M 9%  0s
 ........t, ............... 7..... ............ ..   .................. 1% 7............. s
Rera/r-rconng . .....0sequs
 00K 80K ............ ...... 0s

 4|::4 .. .4, ............. 4.....................  . ....... 0K .. 41 .......


# Giving Context To Tokens

### Adding Embedding (latent factors?) For Each Token

In [14]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd):
        super().__init__()
        # we are not doing vocab size to vocab size anymore, we are doing vocab size to n_embd
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # then we need to project it back to vocab size
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, token_id, targets=None):
        # we are not getting logits directly but we are getting the embeddings
        tok_emb = self.token_embedding_table(token_id) # careful (B,T,C) the C here is an embedding C
        
        # project it back to vocab size
        logits = self.lm_head(tok_emb) # careful (B,T,vocab_size) the C here is the token number in the vocab size?

        # at this point logits represents the probability of each token in the vocab
        if targets is None:
            loss = None
        else:            
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss


### Adding Positional Encoding

Since transformers do not inherently capture the order of tokens (unlike RNNs), another structure is required to capture the spatial information of the tokens

In [15]:
# Before that, let's recap what we pass to our model:

# we pass a batch of token ids (sentences) for each position
xb = get_batch('train')[0]
xb.shape

torch.Size([32, 8])

In [16]:
# and one row of token ids (sentence) looks like this
x = get_batch('train')[0][0]
x.shape, x, tokz.decode(x) 


(torch.Size([8]), tensor([40, 52, 35, 52, 50, 38, 49, 36]), 'hubuserc')

In [17]:
# let's forget about the batch for now and zoom into ONE SINGLE TOKEN

n_embd = 20

tok = x[0]
print('tok id is', tok.item(), tok.shape, '\n')

tok_emb = nn.Embedding(tokz.vocab_size, n_embd)
print('each token has 32 hidden properties', tok_emb)
print('get the 32 hidden properties for token', tok.item(), '-> tok * token embeddings ->', tok_emb(tok).shape)

print('')
pos_emb = nn.Embedding(8, n_embd)
print('each position has 32 hidden properties', pos_emb)
print('get the 32 hidden properties for position 1 to 8', '-> [0, ..., 7] * pos embeddings ->', pos_emb(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], device=cfg.device)).shape)


tok id is 40 torch.Size([]) 

each token has 32 hidden properties Embedding(60, 20)
get the 32 hidden properties for token 40 -> tok * token embeddings -> torch.Size([20])

each position has 32 hidden properties Embedding(8, 20)
get the 32 hidden properties for position 1 to 8 -> [0, ..., 7] * pos embeddings -> torch.Size([8, 20])


Finally combine the two embeddings

In [18]:
res = tok_emb(tok) + pos_emb(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], device=cfg.device))
print('identity and position info ->', res.shape)

identity and position info -> torch.Size([8, 20])


And all of this works exactly the same at the batch level (instead of a token level) thanks to the magic of vectorization.
Even without changing the layers shapes!

In [19]:
# let's forget about the batch for now and zoom into ONE SINGLE TOKEN
print('each token has 32 hidden properties', tok_emb)
print('get the 32 hidden properties for token', tok.item(), '-> tok * token embeddings ->', tok_emb(xb).shape)

print('')
print('each position has 32 hidden properties', pos_emb)
print('get the 32 hidden properties for position 1 to 8', '-> [0, ..., 7] * pos embeddings ->', pos_emb(torch.arange(8, device=cfg.device)).shape)

print('')
res = tok_emb(xb) + pos_emb(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], device=cfg.device))
print('identity and position info ->', res.shape)


each token has 32 hidden properties Embedding(60, 20)
get the 32 hidden properties for token 40 -> tok * token embeddings -> torch.Size([32, 8, 20])

each position has 32 hidden properties Embedding(8, 20)
get the 32 hidden properties for position 1 to 8 -> [0, ..., 7] * pos embeddings -> torch.Size([8, 20])

identity and position info -> torch.Size([32, 8, 20])


In [20]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
 
    def forward(self, token_id, targets=None):
        B,T = token_id.shape

        # get embedding for token i
        tok_emb = self.token_embedding_table(token_id) # careful (B,T,C) the C here is an embedding C

        # we are adding the positional embeddings
        pos_emb = self.positional_embedding_table(torch.arange(T, device=token_id.device)) # careful (T,C) the C here is an embedding C

        # combine both identity and positional embeddings
        # thanks to broadcasting, pos_emb will be added to each token in the sequence
        x = tok_emb + pos_emb # (B,T,C)

        # plug that into a linear layer that will project it back to vocab size
        logits = self.lm_head(x) # careful (B,T,vocab_size) so the C here is the token number in the vocab size

        if targets is None:
            loss = None
        else:            
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [21]:
trainer = Trainer(model, cfg, get_batch, tokz)
trainer.train()

step 0: train loss 1.0779, val loss 2.2225
step 300: train loss 1.0811, val loss 2.3513
step 600: train loss 1.0884, val loss 2.4847
step 900: train loss 1.0913, val loss 2.4882
step 1200: train loss 1.0845, val loss 2.5358
step 1500: train loss 1.0815, val loss 2.5726
step 1800: train loss 1.0730, val loss 2.6431
step 2100: train loss 1.0957, val loss 2.6723
step 2400: train loss 1.0968, val loss 2.6715
step 2700: train loss 1.0831, val loss 2.7073

 ..........  8......... ........... ....cthuses
 ting ...gtt.................  ... ............... .. 9150congines
 26:15.... 1% . ....... 6M ......... 2%  0cont/t.. .....4...  .. ......g s
HTP ......c0ctintit.. ........................ ...........
HTTTP .ctenpububus
 54500com) [txt................... ............ ... ............. 54   0s
   60K  .8%  ..........1525000 ....  1% 247........00K s
Lentxtom) ......40s
 .... 8%   ...00K txthubues
 rc0K ... ..... 87. ................ ............. 2M


# Self-Attention

As humans we know that the meaning of words depends on the context. But right now this context component is not part of our model architecture.

One token has no information about the position of other tokens (is that only about position though?)

### First, What Is Context?

First, we gonna explore a mathematical trick for an efficient implementation of the attention mechanism. Consider this one sample

In [22]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

One thing we can do right now to represent context, is to represent it as an average of previous token, ofc average is a lossy operation we lose lots of information, but it's a start.

In [23]:
# bow = bag of words, average of some words
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [24]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

### An Efficient Implementation Using Triangular Operations

Now there's a way to do that without loops, and it's to use matrix multiplications

In [25]:
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

print('a=', a, "\n")
print('b=', b, "\n")
print('c=', c, "\n")

a= tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) 

b= tensor([[2., 7.],
        [6., 4.],
        [6., 5.]]) 

c= tensor([[14., 16.],
        [14., 16.],
        [14., 16.]]) 



In [26]:
# tril is the lower triangular part of the matrix
torch.tril(torch.ones(3, 3))
# so this thing looks like an accumulation of things

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [27]:
# and we can use it as an identity mask

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

print('a=', a, "\n")
print('b=', b, "\n")
print('c=', c, "\n")

a= tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]]) 

b= tensor([[2., 7.],
        [6., 4.],
        [6., 5.]]) 

c= tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]]) 



In [28]:
# or do averages!

torch.manual_seed(42)
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

print('a=', a, "\n")
print('b=', b, "\n")
print('c=', c, "\n")

# c =
# [first row],
# [avg(2,6), avg(7,4)]
# [avg(2, 6, 6), avg(7, 4, 5)]

a= tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]]) 

b= tensor([[2., 7.],
        [6., 4.],
        [6., 5.]]) 

c= tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]]) 



all of this to say that we can do that for loop...

In [29]:
# version 1: slow nested loop
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

like this! And ofc it's important for our model to vectorize quickly so yeah it's not essential but we will need that for a decent model

In [30]:
# version 2: tril trick
wei = torch.tril(torch.ones(T, T))
wei = wei / torch.sum(wei, dim=1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ---> (B, T, C)
torch.allclose(xbow, xbow2) # compare that they are same

True

In [31]:
# version 3: use softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3) # compare that they are same

True

### What Problem Does Self Attention Solve? What Was People Doing Before That?

Let's forget bigram model for now and think about a real language model that generate words.

Without attention, we can predict the probability of the next word, but this generation depends entirely on the probability of what's after a word, there's no context about the sentence or the whole paragraph/article.

For example consider this example:
__'The chicken didn't across the road because it... ?'__

Without attention, without context, the model has no idea what 'it' refers to, it's just gonna generate something that usually come after it, without even considering if it's related to the chicken or the road, basically it has no idea about grammar or how sentences are constructucted.

### Recap Of Linear And Embedding Layers

In [32]:
# Embedding = index lookup
indices = torch.tensor([1, 2, 3])
W = nn.Parameter(torch.randn(65, 65))
# It takes one or more indices and returns the corresponding values
W[indices].shape


torch.Size([3, 65])

In [33]:
# Linear Layer 
input = torch.randn(32, 65)
W = nn.Parameter(torch.randn(65, 65))
# It takes one or more values and returns the corresponding values in that new space
(input @ W.t()).shape #+ B

torch.Size([32, 65])

### Coding Self-Attention

Now, that we saw how leverage the triangle trick, let's focus on how to implement self-attention.

But what is self-attention?
- It's the mechanism of baking the 'understanding' of other tokens (not only from the sentence?) into the one we are currently encoding.
- It does so by looking at other positions in the sentence and calculating the 'attention' scores for each word PAIR in a sentence, the advatange of this method is that a word very far at the end can know its dependency to a word very far at the beginning (or anywhere actually)
- Does it mean the embedding size of attention is vocab * vocab then?! It would be huge?!

In [34]:
torch.manual_seed(1337)

# version 4: self attention!

# assume we have 2 tokens in the vocab size
B,T,C = 4,8,2
x = torch.randn(B,T,C)

x = torch.randn(C)

# let's see a single Head performing self attention
head_sz = 16
key = nn.Linear(C, head_sz, bias=False) # why bias = False?
query = nn.Linear(C, head_sz, bias=False)

key, query

(Linear(in_features=2, out_features=16, bias=False),
 Linear(in_features=2, out_features=16, bias=False))

In [35]:
# basically hash the input to a key
k = key(x) # (B,T,head_sz)

# hash the input to a query
q = query(x) # (B,T,head_sz)
k.shape, q.shape

(torch.Size([16]), torch.Size([16]))

Let's do it again one level higher with one sentence

In [36]:
B,T,C = 4,8,2
x = torch.randn(T,C)

# let's see a single Head performing self attention
head_sz = 16
key = nn.Linear(C, head_sz, bias=False) # why bias = False?
query = nn.Linear(C, head_sz, bias=False)
k = key(x)
q = query(x)
wei = q @ k.T # (T,H) @ (H,T) -> (T,T)
wei.shape

torch.Size([8, 8])

Again one level higher with one batch

In [37]:
B,T,C = 4,8,2 # B = batch size, T = sequence length, C = channels
x = torch.randn(B,T,C)

# let's see a single Head performing self attention
head_sz = 16
key = nn.Linear(C, head_sz, bias=False) # why bias = False?
query = nn.Linear(C, head_sz, bias=False)
k = key(x)
q = query(x)
# wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5# (B,T,H) @ (B,H,T) -> (B,T,T)
# we need to do this version of wei otherwise we weill get super long floating values in the array somehow?! 
wei = q @ k.transpose(-2, -1)
wei.shape

torch.Size([4, 8, 8])

Time to plug that into our mathematical trick to represent previous tokens

In [38]:
tril = torch.tril(torch.ones(T, T))
# wei is not zero anymore it's the value of the dot product of the query and key (so the result of the attention?)
# wei = torch.zeros(T, T)
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
xbow3.shape

torch.Size([4, 8, 2])

In [39]:
torch.set_printoptions(sci_mode=False)
wei[0]

tensor([[ 0.0954,  0.0514, -0.0370,  0.1323,  0.1924, -0.0442, -0.3174, -0.5568],
        [ 0.1274, -0.0367, -0.0224,  0.2060,  0.4789, -0.0287, -0.7192, -1.3309],
        [-0.0565, -0.0034,  0.0150, -0.0859, -0.1710,  0.0184,  0.2639,  0.4806],
        [ 0.1111,  0.0892, -0.0506,  0.1458,  0.1621, -0.0600, -0.2871, -0.4843],
        [ 0.0322,  0.2394, -0.0695, -0.0173, -0.4031, -0.0789,  0.5157,  1.0506],
        [-0.0661, -0.0053,  0.0179, -0.1001, -0.1972,  0.0219,  0.3050,  0.5549],
        [-0.1042, -0.3516,  0.1163, -0.0621,  0.4127,  0.1334, -0.4822, -1.0401],
        [-0.1329, -0.6591,  0.2024, -0.0203,  0.9707,  0.2307, -1.2063, -2.5023]],
       grad_fn=<SelectBackward0>)

If I am the 6th node, I am not supposed to be able to see what's oin the 7th or the 8th node, only what has come before, so we use our little masking trick with the triangular operation

In [40]:
wei = wei.masked_fill(tril == 0, float('-inf'))
wei[0]

tensor([[ 0.0954,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1274, -0.0367,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.0565, -0.0034,  0.0150,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.1111,  0.0892, -0.0506,  0.1458,    -inf,    -inf,    -inf,    -inf],
        [ 0.0322,  0.2394, -0.0695, -0.0173, -0.4031,    -inf,    -inf,    -inf],
        [-0.0661, -0.0053,  0.0179, -0.1001, -0.1972,  0.0219,    -inf,    -inf],
        [-0.1042, -0.3516,  0.1163, -0.0621,  0.4127,  0.1334, -0.4822,    -inf],
        [-0.1329, -0.6591,  0.2024, -0.0203,  0.9707,  0.2307, -1.2063, -2.5023]],
       grad_fn=<SelectBackward0>)

We don't want the negative stuff, we could do a ReLU to get rid of them, but we also want to spread the logits into probabilities, so we do a softmax.

In [41]:
wei = F.softmax(wei, dim=1)
wei[0]

tensor([[0.1368, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1412, 0.1474, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1175, 0.1524, 0.1620, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1389, 0.1672, 0.1518, 0.2331, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1284, 0.1943, 0.1489, 0.1980, 0.1185, 0.0000, 0.0000, 0.0000],
        [0.1164, 0.1521, 0.1625, 0.1822, 0.1456, 0.2985, 0.0000, 0.0000],
        [0.1120, 0.1076, 0.1793, 0.1893, 0.2679, 0.3337, 0.6735, 0.0000],
        [0.1089, 0.0791, 0.1954, 0.1974, 0.4680, 0.3678, 0.3265, 1.0000]],
       grad_fn=<SelectBackward0>)

Finally we can get the value of x, query, and key


In [42]:
value = nn.Linear(C, head_sz, bias=False)
v = value(x) # similar to key and value it's also (B,T,H)
v.shape

torch.Size([4, 8, 16])

In [43]:
out = wei @ v # (B,T,T) @ (B,T,H) -> (B,T,H)
out.shape, out

(torch.Size([4, 8, 16]),
 tensor([[[    -0.0128,      0.0353,      0.0305,     -0.0174,      0.0394,
                0.0103,     -0.0258,     -0.0273,      0.0612,      0.0775,
                0.0111,     -0.0777,     -0.0626,     -0.0206,     -0.0109,
               -0.0371],
          [    -0.0239,      0.1240,      0.0902,     -0.0840,      0.0887,
                0.0165,     -0.1040,     -0.0577,      0.1881,      0.2194,
                0.0700,     -0.2267,     -0.1564,     -0.0319,     -0.0667,
               -0.1210],
          [    -0.0152,      0.0857,      0.0613,     -0.0594,      0.0583,
                0.0102,     -0.0727,     -0.0376,      0.1284,      0.1484,
                0.0503,     -0.1538,     -0.1042,     -0.0196,     -0.0478,
               -0.0830],
          [    -0.0475,      0.1560,      0.1275,     -0.0867,      0.1529,
                0.0371,     -0.1197,     -0.1043,      0.2590,      0.3198,
                0.0624,     -0.3237,     -0.2495,     -0.0739,  

### note 1: attention as communication

### note 2: attention has no notion of space, operates over sets

### note 3: there is no communication across batch dimension

Assuming we have `B,T,C = 4,8,2`, we got 4 separate pools of 8 nodes/tokens that communicate with each other. They don't communicafte accross pools (batches).

Attention is just a set of vectors, they communicate, if you want them to have a notion of space (position,sequence) you need to add that component, which we will later by combining the attention with the positional embeddings???

> Each example across batch dimension is of course processed completely independently and never "talk" to each other

### note 4: encoder blocks vs. decoder blocks

The self attention block we did is called a decoder attention block because it has triangular masking, we prevented future tokens from communicating with past tokens. This setup is masking is usually used in autoregressive seettings, like language modelling.

But if we want to do an encoder attention block, we just have to delete the single line that does masking `wei.masked_fill(tril == 0, float('-inf'))`, allowing all tokens to communicate with each other. For example in sentiment analysis, you would want all nodes to talk to each other because you want to get the sentiment of the whole article, not just predicting the next sequence.

The beauty of attention is that it doesn't care, you can decide yourself how communication is made, from the future only, from the past only, or some crazy rule you invented.

### note 5: attention vs. self-attention vs. cross-attention

- `self-attention` means that the keys and values are produced from the sace source as queries.
- In `cross-attention`, the queries still get produced from `x` but the `keys` and `values` come from some other source (e.g. an encoder module)

### note 6: "scaled" self-attention. why divide by sqrt(head_size)

__It's an important normalization to have__

# Building the Transformer

In [44]:
k.var(), q.var(), wei.var(), torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

(tensor(0.2749, grad_fn=<VarBackward0>),
 tensor(0.3859, grad_fn=<VarBackward0>),
 tensor(0.0389, grad_fn=<VarBackward0>),
 tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872]))

In [45]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])