# Self attention mechanism. Transformer architecture

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Load the entire training corpus into memory as a single string.
with open('data/tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
# Size of the corpus, in characters.
len(text)

In [None]:
# Preview the first 1000 characters of the corpus.
print(text[:1000])

In [None]:
# Vocabulary: every distinct character that occurs in the corpus, sorted.
# sorted() accepts any iterable, so the set does not need to be copied
# into a list first.
chars = sorted(set(text))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)

In [None]:
# Lookup tables between characters and their integer token ids:
# itos maps id -> character, stoi maps character -> id.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Return the list of integer token ids for the characters of ``s``.

    Raises KeyError if ``s`` contains a character that is not in ``stoi``.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled out by the iterable of token ids ``l``.

    Raises KeyError if an id is not present in ``itos``.
    """
    return ''.join(itos[i] for i in l)


# Round-trip check: decoding the encoding gives back the original string.
print(encode('hii there'))
print(decode(encode('hii there')))

In [None]:
import torch
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

# Encode the whole corpus once as a 1-D tensor of integer token ids that
# lives on the selected device.
data = torch.tensor(
    encode(text),
    dtype=torch.long,
    device=device,
)
print(data.shape, data.dtype)
print(data[:1000])

In [None]:
# Train on the first 90% of the token stream and hold out the last 10% for
# validation. The validation slice must start at n: slicing [:n] again would
# make it identical to the training data and hide any overfitting.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [None]:
# block_size is the context length. A window of block_size + 1 tokens
# yields block_size (context, target) training pairs.
block_size = 8
train_data[:block_size+1]

In [None]:
# Inputs are the first block_size tokens; targets are the same window
# shifted one position to the right.
x, y = train_data[:block_size], train_data[1:block_size+1]

# Every prefix x[:t+1] is a training context whose label is y[t].
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(context, '-->', target)

In [None]:
torch.manual_seed(1337)  # make the randomly sampled batches reproducible
batch_size = 4  # number of sequences sampled per batch
block_size = 8  # length of each sampled sequence (context length)

def get_batch(split):
    """Sample a random mini-batch of input/target windows.

    ``split`` equal to ``'train'`` draws from ``train_data``; any other value
    draws from ``val_data``. The global ``batch_size`` and ``block_size`` are
    read at call time.

    Returns a pair ``(x, y)`` of tensors of shape (batch_size, block_size),
    where ``y`` is ``x`` shifted one position to the right in the source.
    """
    source = train_data if split == 'train' else val_data
    # Random start offsets, chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and show both tensors.
xb, yb = get_batch('train')
for label, batch in (('inputs:', xb), ('targets:', yb)):
    print(label)
    print(batch.shape)
    print(batch)
print('-' * 5)

# Spell out every (context -> target) example contained in the batch:
# each row contributes block_size examples, one per prefix length.
for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t+1], yb[b, t]
        print('if', context.tolist(), '--->', target)

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram model: the next-token logits depend only on the current token.

    The model is a single (vocab_size x vocab_size) embedding table whose
    row ``i`` holds the logits predicted after token ``i``.
    """

    def __init__(self, vocab_size) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=vocab_size,
            device=device,
        )

    def forward(self, idx, targets):
        """Return ``(logits, loss)`` for a batch of token ids.

        ``idx`` and ``targets`` are (B, T) integer tensors. The returned
        logits are flattened to (B*T, C), the layout ``F.cross_entropy``
        is called with; ``loss`` is the resulting cross-entropy.
        """
        scores = self.token_embedding_table(idx)  # (B, T, C)
        batch, time, channels = scores.shape
        flat_logits = scores.view(batch * time, channels)
        flat_targets = targets.view(-1)
        return flat_logits, F.cross_entropy(flat_logits, flat_targets)
    

# Sanity check on one batch: the returned logits are flattened to
# (B*T, vocab_size) and loss is the cross-entropy against the targets.
m = BigramLanguageModel(vocab_size)
out, loss = m(xb, yb)
print(out.shape, loss)


In [None]:
# Expected loss of a model that predicts uniformly over the vocabulary:
# -ln(1 / vocab_size) = ln(vocab_size).
-np.log(1/vocab_size)

In [None]:
class BigramLanguageModel(nn.Module):
    """Bigram model: the next-token logits depend only on the current token.

    The model is a single (vocab_size x vocab_size) embedding table whose
    row ``i`` holds the logits predicted after token ``i``.
    """

    def __init__(self, vocab_size) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size, device=device)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for a (B, T) batch of token ids.

        Without ``targets`` the logits keep their (B, T, C) shape and
        ``loss`` is ``None``. With ``targets`` (also (B, T)) the logits are
        flattened to (B*T, C), the layout ``F.cross_entropy`` is called
        with, and ``loss`` is the resulting cross-entropy.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(-1))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend ``idx`` (B, T) by sampling ``max_new_tokens`` further tokens.

        Returns a (B, T + max_new_tokens) tensor whose first T columns are
        the given ``idx``.
        """
        # Sampling never needs gradients; disabling autograd avoids building
        # a graph on every step.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits, _ = self(idx)
                logits = logits[:, -1, :]  # keep only the last time step: (B, C)
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx
    
# Because targets are supplied, the returned logits are flattened to
# (B*T, vocab_size) rather than (B, T, vocab_size).
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
xb.shape, logits.shape, logits

In [None]:
# Seed generation with a single token of id 0 (a batch of one sequence),
# sample 100 further tokens and decode the result to text.
idx = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(idx, max_new_tokens=100)[0].cpu().tolist()))

In [None]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# get_batch reads the global batch_size, so rebinding it here changes the
# size of every batch sampled below.
batch_size = 32
# Number of optimisation steps; each step uses one freshly sampled batch.
total_epochs = 10000
report_every = total_epochs // 10  # print the loss ten times over the run

for step in range(total_epochs):
    # Sample a batch and run the forward pass.
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # Clear the old gradients, back-propagate, and update the parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % report_every == 0:
        print(loss.item())

In [None]:
# Sample 500 new tokens, starting from a single token of id 0, and decode
# them to text.
idx = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(idx, max_new_tokens=500)[0].cpu().tolist()))

# Self attention

In [None]:
# Toy input: B sequences, T time steps each, C channels per step.
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

In [None]:
# we want x[b, t] = mean_{i<=t} x[b,i]
# Reference implementation with explicit loops: a running average over time.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t+1, C): every step up to and including t
        xbow[b, t] = xprev.mean(dim=0)

for shown in (xbow.shape, x[0], xbow[0]):
    display(shown)

In [None]:
# Lower-triangular matrix of ones with every row normalised to sum to 1:
# row t then holds the uniform weight 1/(t+1) over positions 0..t.
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
wei

In [None]:
# The same running average as one matrix product; the (T, T) weights are
# broadcast across the batch dimension of x.
xbow2 = wei @ x # (T, T) @ (B, T, C) -> (B, T, C)
# Check that the result matches the looped version.
torch.allclose(xbow, xbow2)

In [None]:
# A different way to get the same result:
# start from all-zero scores and set every future position (above the
# diagonal) to -inf, so that a softmax gives those positions zero weight.
wei = torch.zeros((T, T))
tril = torch.tril(torch.ones((T, T)))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

In [None]:
# Softmax over each row: -inf entries become 0 and the remaining (equal)
# scores become uniform weights.
wei = F.softmax(wei, dim=-1)
wei

In [None]:
# Putting it all together: mask out future positions, normalise each row
# with a softmax, then aggregate the past with one matrix product.
tril = torch.tril(torch.ones((T, T)))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
(wei @ x).shape

This idea is interesting because you can think of `wei` as the affinity between each character and the characters that precede it. In this case, the affinities are uniformly distributed.

The -inf entries in the upper triangle of the weight matrix force information to flow only from past characters, never from future ones.

The next step is to replace the uniform distribution of affinities with one that is data-driven.

In [None]:
torch.manual_seed(133)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Single head of self-attention: the weights are now computed from the data.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k, q = key(x), query(x)  # each (B, T, head_size)
# Raw affinities: dot product of every query with every key.
scores = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)

# Hide future positions, then normalise each row into attention weights.
tril = torch.tril(torch.ones((T, T)))
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# In this version the raw inputs themselves are aggregated.
out = wei @ x
out.shape

In [None]:
# Attention weights of the first batch element, truncated to three decimals.
torch.trunc(wei[0]*1000) / 1000

In [None]:
# Let's also project x into a value vector, so that the head aggregates
# values instead of the raw inputs.

torch.manual_seed(133)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# Single head of self-attention with key, query and value projections.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k, q, v = key(x), query(x), value(x)  # each (B, T, head_size)
# Raw affinities: dot product of every query with every key.
scores = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)

# Hide future positions, then normalise each row into attention weights.
tril = torch.tril(torch.ones((T, T)))
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

out.shape

As a summary of self-attention, for each head:
- Each character contributes to the head according to its value
- Each character has characteristics encoded in its key
- Each character is interested in some characteristics encoded in its query

Then:
- The key and the query are combined, yielding a matrix of affinities (how interested each character is in each of the others)
- This matrix is combined with the encoded values to obtain the attention-weighted output



In [None]:
# The attention code in the cells above never uses the position of each
# element, so the order of the characters is not taken into account.
# A learned positional embedding (one vector per position) adds that information.
n_embd = 32
# One row per position; positions must therefore be smaller than block_size.
wpe = nn.Embedding(block_size, n_embd)

torch.manual_seed(133)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

pos = torch.arange(0, T, dtype=torch.long).unsqueeze(0) # shape (1, T)
pos_emb = wpe(pos) # shape (1, T, n_embd)
pos_emb.shape

In [None]:
# pos_emb (1, T, n_embd) broadcasts over the batch dimension of x (B, T, C);
# the sum is valid here because n_embd equals C.
(pos_emb+x).shape

Notes:
- Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information as a weighted sum from all the nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens
- Each example across batch dimension is of course processed completely independently and never 'talk' to each other
- In an 'encoder' attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a 'decoder' attention block because it has triangular masking, and is usually used in autoregressive settings, like language modelling
- 'self-attention' just means that the keys and values are produced from the same source as the queries. In 'cross-attention', the queries are still produced from x, but the keys and values come from another, external source (e.g. an encoder module).
- 'Scaled' attention additionally multiplies wei by 1/sqrt(head_size) (i.e. divides it by sqrt(head_size)). This makes it so that when the inputs Q and K have unit variance, wei has unit variance too, and the softmax stays diffuse instead of saturating.

[Attention is all you need (PDF)](papers/NIPS-2017-attention-is-all-you-need-Paper.pdf)

More details and implementation [here](https://pub.towardsai.net/build-your-own-large-language-model-llm-from-scratch-using-pytorch-9e9945c24858)