# Transformer
following tutorial in: https://www.youtube.com/watch?v=kCc8FmEb1nY

In [1]:
# Get Shakespeare database
with open('data/tiny-shakespeare.txt') as f:
    text = f.read()

In [2]:
print('Length of dataset in chars: ', len(text))

Length of dataset in chars:  1115394


In [3]:
# first 200 chars
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [4]:
# find unique chars
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [9]:
# Encode chars to integers and vice versa
c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for i, c in enumerate(chars)}

# encode - take a string and encode in integers
encode = lambda string: [c2i[c] for c in string]

# decode - take a list of integers and produce text
decode = lambda list: ''.join([i2c[i] for i in list])

print(encode('hi there'))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


In [10]:
# encode as a tensor
import torch
data = torch.tensor(encode(text), dtype=torch.long)

In [11]:
# split train and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [12]:
# max context length
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [13]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'Input: {context}, target: {target}')

Input: tensor([18]), target: 47
Input: tensor([18, 47]), target: 56
Input: tensor([18, 47, 56]), target: 57
Input: tensor([18, 47, 56, 57]), target: 58
Input: tensor([18, 47, 56, 57, 58]), target: 1
Input: tensor([18, 47, 56, 57, 58,  1]), target: 15
Input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
Input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [15]:
# batches
batch_size = 4
block_size = 8 # time

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack( [data[i:i+block_size] for i in ix] )
    y = torch.stack( [data[i+1:i+block_size+1] for i in ix] )
    return x, y

xb, yb = get_batch('train')
for b in range(batch_size):     # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'Input: {context}, target: {target}')

Input: tensor([53]), target: 58
Input: tensor([53, 58]), target: 46
Input: tensor([53, 58, 46]), target: 47
Input: tensor([53, 58, 46, 47]), target: 52
Input: tensor([53, 58, 46, 47, 52]), target: 45
Input: tensor([53, 58, 46, 47, 52, 45]), target: 1
Input: tensor([53, 58, 46, 47, 52, 45,  1]), target: 39
Input: tensor([53, 58, 46, 47, 52, 45,  1, 39]), target: 50
Input: tensor([57]), target: 1
Input: tensor([57,  1]), target: 43
Input: tensor([57,  1, 43]), target: 44
Input: tensor([57,  1, 43, 44]), target: 44
Input: tensor([57,  1, 43, 44, 44]), target: 43
Input: tensor([57,  1, 43, 44, 44, 43]), target: 41
Input: tensor([57,  1, 43, 44, 44, 43, 41]), target: 58
Input: tensor([57,  1, 43, 44, 44, 43, 41, 58]), target: 10
Input: tensor([61]), target: 46
Input: tensor([61, 46]), target: 43
Input: tensor([61, 46, 43]), target: 52
Input: tensor([61, 46, 43, 52]), target: 1
Input: tensor([61, 46, 43, 52,  1]), target: 51
Input: tensor([61, 46, 43, 52,  1, 51]), target: 63
Input: tensor([

## Bigram Language Model

In [21]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1000)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # read logits for the next token from the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor
        logits = self.token_embedding_table(idx) # (B,T,C), C is channel or vocab size
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # channel is expected as second
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices
        for _ in range(max_new_tokens):
            # get predictions
            logits, _ = self(idx)
            # work only from the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append new data to time dimension
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
out, loss = m(xb, yb)
print(out.shape)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx=idx, max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.9138, grad_fn=<NllLossBackward0>)

ljLZqqVxAwLGPFvPj&IbaEeJH?OF-C3jSUy'Aw$nIyxBxOSYkgS;mhUf!TFV&AfBO'jAUhxRJ-bM
Ze$fLSWaKuRSG;c?KUBSE!V


In [22]:
# create a pytorch optimizer (typical lr is 1e-4, for smaller models can bigger)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [25]:
batch_size = 32
for steps in range(10000):
    # sample batch data
    xb, yb = get_batch('train')

    # eval the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.4942922592163086


In [27]:
idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx=idx, max_new_tokens=500)[0].tolist()))


ge sth l
NBens me,
G ok?

Corn casirith
Finore fengs, as, leand-fe cofthegrap,-my.
PA:
Be fr's Vove m,-manouga TCANDIOFLUSOOXur ETathe theore hant h cthag y st tse ourat ber soutothayoris anedkat, d f ders galwngathrogupr mindem aches:

NIZWhig. rut:

A bendat y o sererseayos cknd,
tharofl murf otind, mat:
II: lisieay oudeeged
OLAnde,
Fotul pour.
Ber, wegig hefou!
CAs nvis, g usath h suris

KENTd ak, osagedar ss,-m ss w.


Boutore g'lewhenoulle ar lem r ucosendus fu

RYoy my bea and t'd,
As: s s


## Self Attention

In [28]:
torch.manual_seed(1000)
B, T, C = 4, 8, 2 # batch, time, channel
x = torch.rand(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [33]:
# a simple form of attention would be to average across past and current token
# we want x[b, t] = mean_{i<=t} x[b, i]
# averaging is analogous to bag of words
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b, t] = torch.mean(xprev, 0)

In [34]:
x[0]

tensor([[0.3189, 0.6136],
        [0.4418, 0.2580],
        [0.2724, 0.6261],
        [0.4410, 0.3653],
        [0.3535, 0.5971],
        [0.3572, 0.4807],
        [0.4217, 0.1254],
        [0.6818, 0.0571]])

In [35]:
xbow[0]

tensor([[0.3189, 0.6136],
        [0.3804, 0.4358],
        [0.3444, 0.4992],
        [0.3685, 0.4657],
        [0.3655, 0.4920],
        [0.3641, 0.4901],
        [0.3723, 0.4380],
        [0.4110, 0.3904]])

In [36]:
# use matrix mult for efficiency
# e.g. 2nd row of tril averages first two elements
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(a)
print(b)
print(c)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [37]:
# average using a matrix
wei = torch.tril(torch.ones(T, T))
wei = wei / torch.sum(wei, dim=1, keepdim=True)
xbow2 = wei @ x # (T, T) @ (B, T, C) --> (B, T, C), because of broadcasting

In [38]:
torch.allclose(xbow, xbow2)

True

In [40]:
# averaging using softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)) # weights can be looked as strength of connection
wei = wei.masked_fill(tril==0, float('-inf')) # override those in the past, which cannot be used
wei = F.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

## Self Attention for a Single Head

In [43]:
torch.manual_seed(1000)
B, T, C = 4, 8, 32 # batch, time, channel
x = torch.randn(B, T, C)

# single head attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # just matrix multiply
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, head_size)
q = query(x)
# communication happens in scalar product
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)
wei = wei * head_size**-0.5 # this is to normalise variance, so that weights are flat at init

tril = torch.tril(torch.ones(T, T))
# weight are now calculated through attention
wei = wei.masked_fill(tril==0, float('-inf')) # override those in the past, which cannot be used
wei = F.softmax(wei, dim=1)

v = value(x)
out = wei @ v

wei[0]

tensor([[9.9233e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.5979e-03, 2.0696e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.7314e-03, 5.1011e-02, 1.7596e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.9592e-02, 8.5036e-01, 6.1478e-01, 4.7717e-03, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [4.0580e-01, 1.4333e-02, 1.3564e-02, 7.5291e-01, 8.0826e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [7.3173e-03, 1.4970e-02, 3.6143e-02, 3.4910e-04, 2.9210e-02, 1.6003e-02,
         0.0000e+00, 0.0000e+00],
        [1.2455e-01, 5.5340e-03, 2.8228e-01, 1.0573e-01, 4.2355e-02, 4.4684e-01,
         1.7650e-01, 0.0000e+00],
        [3.2118e-01, 6.1727e-02, 3.5638e-02, 1.3624e-01, 1.2018e-01, 5.3715e-01,
         8.2350e-01, 1.0000e+00]], grad_fn=<SelectBackward0>)