In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
# Load the whole training corpus into memory as one Python string.
with open('input.txt', mode='r', encoding='utf-8') as fh:
    text = fh.read()

In [7]:
# Character-level vocabulary: every distinct character of the corpus, sorted,
# so the token id of a character is its position in this list.
chars = sorted(set(text))
vocab_size = len(chars)
cc = ''.join(chars)
print(f'{cc} {vocab_size}')

# Tokenization (character-level)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> token id
itos = {i: ch for i, ch in enumerate(chars)}  # token id -> character


def encode(s):
    """Encode a string as a list of integer token ids, one per character.

    Raises KeyError if `s` contains a character that does not occur in `text`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode a list (or other iterable) of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 65


In [15]:
# The entire corpus as a single 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

# Train/validation split: the first 90% of tokens train, the last 10% validate.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

torch.Size([1115393]) torch.int64


In [20]:
torch.manual_seed(1337)
block_size = 8  # context length: number of tokens in each training example
batch_size = 4  # number of independent sequences in each batch


def get_batch(split):
    """Sample a random batch of (input, target) token sequences.

    `block_size` and `batch_size` are read from the notebook globals at call
    time, so later reassignments of those globals change the batch shape.

    Args:
        split: 'train' to sample from `train_data`, 'val' to sample from `val_data`.

    Returns:
        (x, y): two (batch_size, block_size) tensors. y is x shifted one
        position ahead, i.e. y[b, t] is the token that follows x[b, t].

    Raises:
        ValueError: if `split` is neither 'train' nor 'val'.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # randint's upper bound is exclusive, so the largest start index is
    # len(data) - block_size - 1, which leaves room for the shifted target.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

In [27]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    Shape legend used in the comments below:
        B: batch size (number of independent sequences processed in parallel)
        T: "time" dimension, i.e. the number of context tokens per sequence
        C: number of channels; here the logits dimension, equal to vocab_size
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i of this (vocab_size, vocab_size) table holds the logits for the
        # token that follows token i, so "prediction" is a plain lookup.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of the expected next tokens.

        Returns:
            logits: (B, T, C) when `targets` is None; otherwise flattened to
                (B*T, C), the (N, C) input layout used for F.cross_entropy.
            loss: scalar cross-entropy loss, or None when `targets` is None.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)  # (B, T, C) -> (B*T, C)
            targets = targets.view(B * T)   # (B, T) -> (B*T,)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample `max_new_tokens` tokens autoregressively and append them to `idx`.

        Runs under torch.no_grad(): sampling is never backpropagated, so there
        is no reason to build an autograd graph on every step.

        Args:
            idx: (B, T) integer tensor holding the current context.
            max_new_tokens: number of tokens to sample.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            logits = self(idx)[0]               # (B, T, C); loss is None and unused
            logits = logits[:, -1, :]           # keep only the last time step -> (B, C)
            probs = F.softmax(logits, dim=-1)   # (B, C) probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T + 1)
        return idx
    
# Sanity-check the untrained model. Draw a batch here so this cell does not
# depend on `xb`/`yb` leaking in from the training cell further down.
xb, yb = get_batch('train')
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)  # targets were passed, so logits are flattened to (B*T, vocab_size)
# The model is untrained, so this sample will be garbage
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])

EGXxuxFFW:BkQW bnMoNi&zAyrONl?3XHzQmSBr&XxUnfeyI$aCSZRt:WI,tIGxKuGbOX;K-oRnM.VRFKORh.JCvATdTMv!ZPdiA


In [35]:
# Train the Bigram model.
# NOTE: this cell reassigns the global `batch_size`, which get_batch() reads at
# call time, so every get_batch() call made after this cell returns 32-row
# batches. Re-running the cell keeps training `m` with a fresh optimizer.
learning_rate = 1e-3
max_iters = 10000
batch_size = 32

optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

for _ in range(max_iters):
    xb, yb = get_batch('train')  # sample a fresh training batch

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only (a single batch, not an average over the split).
print(loss.item())

2.366422653198242


In [43]:
# Sample 400 new characters from the trained model, starting from a single token id 0.
print(decode(m.generate(idx=torch.zeros(1, 1, dtype=torch.long), max_new_tokens=400).tolist()[0]))



Durd I Jut mid, mubais venearkeses incaforan opug.
BEnd re:
BUTo weror,
S: hovotet
I:
Lagn, ban fan t vigow f fr in d-sansishe, hermesunatucan yowin IUMP: ckes hemou windu r ors th, useneam, thas TMIsurie t y aithaipemaieatioreyor save
G l, analvecirau he, h puthasto hine! O:
Sarorengot nrndigoffatatha I fr brd?
Thy to pst ndan hilis s ch
BLUnowave t, CAng pr pr pe! y hou ts lage w gr mi, ve hand


We want every token to gather context by averaging over its past, so for each token we only consider the tokens up to and including itself — never the "future" ones.

## Version 1: for loop

In [44]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [58]:
# "Bag of words": xbow[b, t] is the mean of x[b, 0], ..., x[b, t], i.e. the
# average of the current token and every token before it.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]            # (t+1, C): tokens 0..t inclusive
        xbow[b, t] = xprev.mean(dim=0)  # (C,)

## Version 2: matrix multiply

The zeroes above the diagonal of the `tril` matrix give zero weight to every token *after* the current one. After normalizing each row to sum to 1, we therefore average everything *up to and including* the current token.

In [59]:
torch.manual_seed(42)
# Lower-triangular ones with every row rescaled to sum to 1:
# row i gives equal weight to rows 0..i of whatever it multiplies.
a = torch.tril(torch.ones(3, 3))
a /= a.sum(dim=1, keepdim=True)
b = torch.randint(low=0, high=10, size=(3, 2)).float()
c = torch.matmul(a, b)  # c[i] is the mean of b[0], ..., b[i]
print(a, b, c, sep='\n')

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [61]:
# The same running average as one matrix multiply: row t of `wei` puts
# weight 1/(t+1) on positions 0..t and 0 on every later position.
wei = torch.tril(torch.ones(T, T))  # "wei" as in "weights"
wei /= wei.sum(dim=1, keepdim=True)
xbow2 = torch.matmul(wei, x)  # (T, T) broadcast over the batch @ (B, T, C) -> (B, T, C)

torch.allclose(xbow, xbow2)  # identical result to the loop, computed far more efficiently

True

## Version 3: adding Softmax

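Softmax is a normalization operation: $\mathrm{softmax}(z)_i = e^{z_i} / \sum_j e^{z_j}$, so each row of weights sums to 1. To give an element zero weight we cannot leave its score at 0, because $e^0 = 1$ would still contribute. Instead we set masked scores to $-\infty$, since $e^{-\infty} = 0$.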

In [64]:
tril = torch.tril(torch.ones(T, T))
# Start from all-zero affinities and block the future: entries above the
# diagonal (where tril is 0) become -inf, so softmax assigns them weight 0
# and spreads the weight uniformly over positions 0..t.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow2, xbow3)  # it's STILL the same result

True

## Version 4: Self-Attention

The key idea here is that not every token in a token's past has to be weighted equally; the weights should depend on the data. For example, if I'm a vowel, I might want to look for consonants in my past.

To do that, every token emits 3 vectors:

+ Query vector: "what am I looking for?"

+ Key vector: "what do I contain?"

+ Value vector: "if you find me interesting, this is what I will communicate to you"

We get the "affinities" between tokens with a dot product between keys and queries: each token computes "my query *dot product* the keys of all my previous tokens". When a key and a query are aligned, their dot product is a big number, so the things I (***Self***) am looking for get more ***Attention***.

In [69]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# A single head of self-attention.
head_size = 16
# Built in key, query, value order so the seeded initialisation is unchanged.
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))

# Every token emits a query ("what am I looking for?") and a key ("what do I contain?").
q = query(x)  # (B, T, head_size)
k = key(x)    # (B, T, head_size)

# Affinities: dot product of every query with every key.
# NOTE: not scaled by head_size**-0.5 here.
wei = torch.matmul(q, k.transpose(1, 2))  # (B, T, hs) @ (B, hs, T) -> (B, T, T)

# Causal mask: token t may only attend to tokens 0..t.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)

# Aggregate the values using the attention weights.
v = value(x)                 # (B, T, head_size)
out = torch.matmul(wei, v)   # (B, T, T) @ (B, T, hs) -> (B, T, hs)

out.shape

torch.Size([4, 8, 16])

In [70]:
wei.select(0, 0)  # attention weights of the first sequence in the batch, shape (T, T)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:

+ Attention is a **communication mechanism**. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights. For example, in a graph of 8 nodes with causal masking, N1 looks only at itself, N2 looks at itself *and* at N1... and N8 looks at itself and at *every other node*.

+ There is no notion of space. Attention simply acts over a set of vectors. This is why we need to *positionally encode* tokens.

+ Each example across the batch dimension is processed completely independently; examples never "talk" to each other.

+ What we built here is a "decoder" attention block: the `tril` mask hides every token after the current one, so each token only attends to its past (and itself). In an "encoder" attention block we simply drop the masking line and let every token attend to all the others: past, present and future.

+ "Self-attention" just means that the keys and values are produced from the same source (`x`) as the queries. In "cross-attention" the queries still come from `x`, but the keys and values come from a separate source of nodes we pull information from.