### Transformer from scratch

TODO: Implement encoder + decoder! (Use cross attention with the KQ, and don't mask in encoder)

In [1]:
# French to English translation example:

# <--------- ENCODE ------------------><--------------- DECODE ----------------->
# les réseaux de neurones sont géniaux! <START> neural networks are awesome!<END>

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken as tk

In [3]:
# string level instead of char level generation
with open('train.de', 'r', encoding='utf-8') as f:
  de_text = f.readlines()
with open('train.en', 'r', encoding='utf-8') as f:
  en_text = f.readlines()
  
de_text[0], en_text[0]

('Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.\n',
 'Two young, White males are outside near many bushes.\n')

In [4]:
# get tokens
enc = tk.get_encoding('gpt2')
vocab_size = enc.max_token_value + 1
print(vocab_size)

# <START> and <END> tokens
de = [torch.tensor(enc.encode('<START> ' + d.replace('\n', '<END>'))) for d in de_text]
en = [torch.tensor(enc.encode('<START> ' + e.replace('\n', '<END>'))) for e in en_text]

print(de[0])

50257
tensor([   27,  2257,  7227,    29,  1168, 42990, 10891,   469,   356,    72,
        39683,    68,   337, 11033,    77,  1008,   264,   521,   545,  4848,
         2013,   287,  4587,   399, 11033,   258,   410,  8207,   263,   347,
         9116, 15952, 29847, 10619,    29])


In [5]:
# add padding to make all de sequences the same length

# get max length
max_len = max([len(d) for d in de])
print(max_len)

# pad
de = [F.pad(d, (0, max_len - len(d))) for d in de]
print(de[0])

102
tensor([   27,  2257,  7227,    29,  1168, 42990, 10891,   469,   356,    72,
        39683,    68,   337, 11033,    77,  1008,   264,   521,   545,  4848,
         2013,   287,  4587,   399, 11033,   258,   410,  8207,   263,   347,
         9116, 15952, 29847, 10619,    29,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0])


In [6]:
enc.encode('<END>')
enc.decode([27,10619,29])

'<END>'

In [7]:
# shuffle sentences and split into train and val
import random

z = list(zip(de, en))
random.shuffle(z)

# split
n = int(0.9*len(z))
train_data = z[:n]
val_data = z[n:]

In [8]:
# example of training data
# x is (de, en) pairs
# y is en_labels, next token in the en sentence

x = (de[0], en[0]) # first sentence
y = en[0][1:]
print(x)
print(y)

for t in range(len(x[1])-1):
  context = x[1][:t+1]
  target = y[t]
  print(f"when input is {context} the target is {target}")

(tensor([   27,  2257,  7227,    29,  1168, 42990, 10891,   469,   356,    72,
        39683,    68,   337, 11033,    77,  1008,   264,   521,   545,  4848,
         2013,   287,  4587,   399, 11033,   258,   410,  8207,   263,   347,
         9116, 15952, 29847, 10619,    29,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0]), tensor([   27,  2257,  7227,    29,  4930,  1862,    11,  2635, 10835,   389,
         2354,  1474,   867, 37413, 29847, 10619,    29]))
tensor([ 2257,  7227,    29,  4930,  1862,    11,  2635, 1

In [9]:
# hyperparameters
# batch_size = 8
block_size = 8 # 64
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
device = 'cpu' # 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 1 
n_embd = 64 
n_head = 4 
n_layer = 4
dropout = 0.2

Training

Input: German sentence (prompt) and English sentence (this part needs block size and masking)
Output: English sentence (next token)

In [10]:
batch_size = 8
block_size = 8

def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data), (batch_size,)) # randomly select batch_size sentences
  de = torch.stack([data[i][0] for i in ix])
  en = [data[i][1] for i in ix]
  en_labels = [data[i][1][1:] for i in ix]
  
  # randomly select block_size tokens from each sentence
  # use same random indices for en and en_labels
  en_b = []
  en_labels_b = []
  for s1, s2 in zip(en, en_labels):
    i = torch.randint(len(s1)-block_size, (1,))[0]
    en_b.append(s1[i:i+block_size])
    en_labels_b.append(s2[i:i+block_size])
  
  en_b = torch.stack(en_b)
  en_labels_b = torch.stack(en_labels_b)
  
  # # create block sizes for each en sentence and concatenate
  # en_b = torch.stack([sentence[i:i+block_size] for sentence in en for i in range(0, len(sentence)-block_size, block_size)])
  # en_labels_b = torch.stack([sentence[i:i+block_size] for sentence in en_labels for i in range(0, len(sentence)-block_size, block_size)])
  
  return (de.to(device), en_b.to(device)), en_labels_b.to(device)

In [11]:
xb, yb = get_batch('train')
# print(xb[0][:5]) # xb[0] is the de sentences
# print(xb[1][:5]) # xb[1] is the en sentences, corresponding to y
# print(yb)

print('---')

for b in range(batch_size):
  for t in range(block_size):
    context = xb[1][b,:t+1]
    target = yb[b,t]
    print(f"when input is {context.tolist()} the target is {target}")

---
when input is [27] the target is 2257
when input is [27, 2257] the target is 7227
when input is [27, 2257, 7227] the target is 29
when input is [27, 2257, 7227, 29] the target is 317
when input is [27, 2257, 7227, 29, 317] the target is 38042
when input is [27, 2257, 7227, 29, 317, 38042] the target is 7720
when input is [27, 2257, 7227, 29, 317, 38042, 7720] the target is 5916
when input is [27, 2257, 7227, 29, 317, 38042, 7720, 5916] the target is 284
when input is [27] the target is 2257
when input is [27, 2257] the target is 7227
when input is [27, 2257, 7227] the target is 29
when input is [27, 2257, 7227, 29] the target is 317
when input is [27, 2257, 7227, 29, 317] the target is 1448
when input is [27, 2257, 7227, 29, 317, 1448] the target is 286
when input is [27, 2257, 7227, 29, 317, 1448, 286] the target is 30303
when input is [27, 2257, 7227, 29, 317, 1448, 286, 30303] the target is 1627
when input is [21671] the target is 257
when input is [21671, 257] the target is 651

In [31]:
class MultiHeadAttention(nn.Module):
  """ one head of self-attention """

  def __init__(self, num_heads):
      super().__init__()
      self.key = nn.Linear(n_embd, n_embd, bias=False)
      self.query = nn.Linear(n_embd, n_embd, bias=False)
      self.value = nn.Linear(n_embd, n_embd, bias=False)
      self.proj = nn.Linear(n_embd, n_embd)
      self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
      self.dropout = nn.Dropout(dropout)
      self.num_heads = num_heads

  def forward(self, q, k, v, mask=None):
      k = self.split(self.key(k)) # (B, T, C) -> (B, H, T, C/H)
      q = self.split(self.query(q))
      # compute attention scores ("affinities") and normalize by head size
      wei = q @ k.transpose(-2,-1) * (n_embd // self.num_heads) ** -0.5 # (B, H, T, C/H) @ (B, H, C/H, T) -> (B, H, T, T
      if mask:
        wei = wei.masked_fill(self.tril[:wei.shape[-2], :wei.shape[-1]] == 0, float('-inf')) # mask out future tokens
      wei = F.softmax(wei, dim=-1) # (B, H, T, T)
      wei = self.dropout(wei)
      # perform the weighted aggregation of the values
      v = self.split(self.value(v))
      # use q below because q will always have the same shape as x
      out = (wei @ v).transpose(1,2).contiguous().view(q.shape[0], q.shape[2], -1) # (B, H, T, T) @ (B, H, T, C/H) -> (B, T, H, C/H) -> (B, T, C)
      return out

  def split(self, x): # split the last dimension into num_heads
      B,T,C = x.shape
      return x.view(B, T, self.num_heads, C // self.num_heads).transpose(1,2)

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [32]:
class EncoderBlock(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        # head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = self.ln1(x)
        x = x + self.sa(q=x, k=x, v=x, mask=None) # residual skip connections +
        x = self.ln2(x)
        x = x + self.ffwd(x)
        return x
      
class DecoderBlock(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        # head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head)
        self.eda = MultiHeadAttention(n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)

    def forward(self, x, enc):
        x = self.ln1(x)
        x = x + self.sa(q=x, k=x, v=x, mask=True)
        # encoder-decoder attention
        x = self.ln2(x)
        x = x + self.eda(q=x, k=enc, v=enc, mask=None)
        x = self.ln3(x)
        x = x + self.ffwd(x)
        return (x, enc) # need to include enc for sequential (next block)

In [33]:
class Encoder(nn.Module):
    """ Transformer encoder: a stack of Transformer blocks """

    def __init__(self, n_embd, n_head, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([EncoderBlock(n_embd, n_head) for _ in range(n_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class Decoder(nn.Module):
    """ Transformer decoder: a stack of Transformer blocks """

    def __init__(self, n_embd, n_head, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([DecoderBlock(n_embd, n_head) for _ in range(n_layers)])

    def forward(self, x, enc):
        for block in self.blocks:
            x, enc = block(x, enc)
        return x, enc

In [34]:
class Transformer(nn.Module):
    """ Transformer model: encoder + decoder """

    def __init__(self, n_embd, max_len, n_head, n_layer):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table_inpt = nn.Embedding(block_size, n_embd)
        self.position_embedding_table_prompt = nn.Embedding(max_len, n_embd)
        self.enc = Encoder(n_embd, n_head, n_layer)
        self.dec = Decoder(n_embd, n_head, n_layer)
        # self.enc = nn.Sequential(*[EncoderBlock(n_embd, n_head=n_head) for _ in range(n_layer)])
        # self.dec = nn.Sequential(*[DecoderBlock(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, y=None):
        prompt, inpt = x
        prompt = self.token_embedding_table(prompt) + self.position_embedding_table_prompt(torch.arange(prompt.shape[1], device=prompt.device))
        inpt = self.token_embedding_table(inpt) + self.position_embedding_table_inpt(torch.arange(inpt.shape[1], device=inpt.device))
        enc_out = self.enc(prompt)
        dec_out, _ = self.dec(inpt, enc_out)
        logits = self.lm_head(self.ln_f(dec_out))
        
        if y is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
        
        return logits, loss

In [35]:
model = Transformer(n_embd, max_len, n_head, n_layer).to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

6.954961 M parameters


In [36]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
print(device)

cpu


In [37]:
# for iter in range(max_iters):
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
print(loss.item())

11.018996238708496


In [39]:
for iter in range(500): # max_iters
  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  
  if iter % 100 == 0:
    print(loss.item())
  
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

11.040460586547852
8.5200777053833


Testing

In [39]:
# try with just one sample

# randomly get a sample
idx = random.randint(0, len(train_data))
print(idx)

de_test = train_data[idx][0]
en_test = train_data[idx][1]
en_label_test = train_data[idx][1][1:]

print(de_test, en_test, en_label_test)

14998
tensor([   27,  2257,  7227,    29,   412,   259, 20291,  1976,   494,  4352,
          304, 42326,   370, 11286,    11,  4587, 10255,   412,   320,  1142,
           11, 24884,   354,    76,  8158,  3318, 30837,   268,   410,   692,
          894, 40780,   318,    83, 29847, 10619,    29,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0]) tensor([   27,  2257,  7227,    29,  1869, 10427,   257,  6383,  1336,   286,
        38674,    11,   285,  2840,   290,  1379,  3150, 29847, 10619,    29]) tensor([ 2257,  7227,    29,  1869

In [40]:
de_test.shape

torch.Size([102])

In [41]:
torch.arange(de_test.shape[0])

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101])

In [42]:
# positional embedding size should actually be max_sentence_len + 1
# for now just use the length of the sentence

token_embedding_table = nn.Embedding(vocab_size, n_embd)
position_embedding_table = nn.Embedding(len(de_test), n_embd)

In [43]:
emb_test = token_embedding_table(de_test)
emb_test2 = position_embedding_table(torch.arange(de_test.shape[0])) # what position each word is in the sentence
emb_test3 = emb_test + emb_test2

emb_test3

tensor([[ 0.3093,  0.7866, -0.5895,  ..., -1.3089,  0.2187,  0.7966],
        [-3.4576, -0.7954, -2.3464,  ...,  1.1559,  0.0368,  0.0379],
        [-2.5936,  1.0514, -0.8126,  ...,  3.7348,  1.2958, -2.9335],
        ...,
        [-2.7214, -1.6235,  0.1155,  ..., -2.2392, -1.1426, -0.8550],
        [-0.7753, -1.3761,  0.8096,  ..., -2.9444, -0.5745, -1.4609],
        [-3.3628, -0.5707,  2.4626,  ...,  0.4927, -0.8976,  0.6312]],
       grad_fn=<AddBackward0>)

In [44]:
emb_test.shape # makes sense because we are getting n_embd embedding for each token

torch.Size([102, 64])

In [45]:
emb_test2.shape, emb_test3.shape

(torch.Size([102, 64]), torch.Size([102, 64]))

Yayy encoding!

In [46]:
emb_test3.unsqueeze(0).shape

torch.Size([1, 102, 64])

In [47]:
# now that we have an embedding, we can pass it through the encoder

enc = EncoderBlock(n_embd, n_head)
# add a dimension to the embedding because the encoder expects a batch of sentences
enc_out_test = enc(emb_test3.unsqueeze(0))
enc_out_test

tensor([[[ 0.2188,  0.7534, -0.7414,  ..., -0.7603,  0.1801,  0.0924],
         [-2.1068, -0.1372, -1.9280,  ...,  0.9991,  0.2325, -0.5434],
         [-1.3446,  1.0218, -0.8601,  ...,  2.7342,  1.1625, -2.0194],
         ...,
         [-1.7012, -0.7692, -0.1831,  ..., -1.1328, -0.8862, -0.5705],
         [-0.2505, -0.5427,  0.4150,  ..., -1.1798, -0.7398, -0.8358],
         [-2.2766, -0.0478,  1.1928,  ...,  0.8402, -0.2884, -0.2580]]],
       grad_fn=<AddBackward0>)

Pass through decoder now

In [48]:
# for now just crop en_test to block_size (actual transformer would use a sliding window)
en_test = en_test[:block_size]
en_label_test = en_label_test[:block_size]
en_test, en_label_test

(tensor([   27,  2257,  7227,    29,  1869, 10427,   257,  6383]),
 tensor([ 2257,  7227,    29,  1869, 10427,   257,  6383,  1336]))

In [49]:
position_embedding_table2 = nn.Embedding(block_size, n_embd)

In [50]:
en_emb1 = token_embedding_table(en_test)
en_emb2 = position_embedding_table2(torch.arange(en_test.shape[0]))
en_emb3 = en_emb1 + en_emb2

In [51]:
en_emb3.unsqueeze(0).shape, enc_out_test.shape

(torch.Size([1, 8, 64]), torch.Size([1, 102, 64]))

In [52]:
dec = DecoderBlock(n_embd, n_head)
dec_out_test = dec(en_emb3.unsqueeze(0), enc_out_test)
dec_out_test

tensor([[[ 4.2742e-02, -6.6307e-01,  9.5389e-01,  6.4754e-01,  5.4297e-01,
           1.1434e+00,  1.5338e+00, -2.6165e-01,  8.4796e-01,  5.3543e-02,
          -2.2576e-01,  7.0200e-01, -3.3272e+00, -3.9943e-01, -2.0145e+00,
           7.4581e-02, -1.2408e+00,  2.4673e-01, -5.0055e-01,  6.7559e-01,
          -1.3922e+00,  2.2750e-01,  5.7054e-01,  1.0686e+00,  1.1364e+00,
           1.5668e+00, -2.2240e-01, -1.6714e+00,  7.8573e-01,  1.4196e+00,
           4.4317e-01, -1.5133e+00,  2.7186e-01, -1.1186e+00, -2.1555e+00,
           1.5989e-01,  1.0834e+00,  1.3989e+00,  8.4593e-01,  1.0235e+00,
           1.8332e-01,  8.6455e-01,  1.0189e-01, -8.9185e-01, -3.6967e-01,
          -1.8523e+00, -1.3451e+00, -3.0293e-01, -2.2932e-01,  6.4974e-01,
           8.3090e-01,  3.7256e-01, -1.7152e+00,  1.1702e+00, -8.9748e-01,
           1.9070e-01,  1.3602e+00,  6.9795e-01, -1.0332e+00,  3.2601e-01,
          -1.0430e-01, -2.9414e-01,  5.1320e-01,  8.8965e-01],
         [-1.7029e+00,  8.7928e-01, -

In [53]:
dec_out_test.shape # logits are (B,T,vocab_size); in this case, (1, block_size, vocab_size)

torch.Size([1, 8, 64])

In [38]:
model.dec

Decoder(
  (blocks): ModuleList(
    (0): DecoderBlock(
      (sa): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=False)
        (query): Linear(in_features=64, out_features=64, bias=False)
        (value): Linear(in_features=64, out_features=64, bias=False)
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (eda): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=False)
        (query): Linear(in_features=64, out_features=64, bias=False)
        (value): Linear(in_features=64, out_features=64, bias=False)
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3

Get logits and evaluate loss (part of Transformer class)

In [54]:
logits = nn.Linear(n_embd, vocab_size)(dec_out_test)
targets = en_label_test
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets.view(-1)) # some reshaping so that there is only 2 dims
loss

tensor(11.3545, grad_fn=<NllLossBackward0>)

In [55]:
logits.shape, targets.shape
# makes sense because we are getting probabilities for each token in vocab being the next token and computing loss

(torch.Size([1, 8, 50257]), torch.Size([8]))

In [56]:
# index into the highest probability token - these are the predictions!
print(logits.argmax(dim=-1))
print(targets)

tensor([[29663, 16217, 10084, 23310, 46877, 11894, 14868, 27781]])
tensor([ 2257,  7227,    29,  1869, 10427,   257,  6383,  1336])


Translate test (sample from model)

In [44]:
# what we feed into the model is german sentence and <START> token

example_de = "ein boot mit mehreren männern darauf wird von einem großen pferdegespann ans ufer gezogen ."
prompt = torch.tensor(enc.encode('<START> ' + example_de + '<END>')).unsqueeze(0)
context = torch.tensor(enc.encode('<START>')).unsqueeze(0)

def generate(model, prompt, context, max_len=100):
  with torch.no_grad():
    for i in range(max_len):
      logits, _ = model((prompt, context))
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, num_samples=1)
      context = torch.cat((context.squeeze(0), next_token.squeeze(0)), dim=0).unsqueeze(0)
      if next_token.item() == enc.encode('<END>'):
        break
  return context

print(generate(model, prompt, context))

IndexError: index out of range in self

In [41]:
context

tensor([[  27, 2257, 7227,   29]])

In [43]:
logits[:, -1, :].argmax(dim=-1)

tensor([   29,   257,   287,   257,   287,    29,   287, 29847])