# Building a Transformer from Scratch

In [20]:
import torch,math
import torch.nn as nn
import numpy as np
from collections import Counter
import json
import os
import urllib.request

## Load Data

In [2]:
# Fetch the training corpus once; skip the download when a local copy exists
if not os.path.exists("the-verdict.txt"):
    file_path = "the-verdict.txt"
    url = "/".join([
        "https://raw.githubusercontent.com/rasbt",
        "LLMs-from-scratch/main/ch02/01_main-chapter-code",
        "the-verdict.txt",
    ])
    urllib.request.urlretrieve(url, file_path)

In [3]:
# Read the whole corpus into memory as a single string
with open("the-verdict.txt", encoding="utf-8") as f:
    text = f.read()

## Tokenizer

We use Byte Pair Encoding (BPE) as our tokenizer. The algorithm consists of the following steps:

1. We initialize the vocabulary with the character vocabulary where each word is represented as a sequence of characters. 
2. We iteratively count all symbol pairs and replace every occurrence of the most frequent pair, say ('A', 'B'), with a new symbol 'AB'.
3. Every merge operation introduces a new symbol which represents a character n-gram. 
4. Frequent character n-grams (or whole words) are eventually merged into a single symbol, thus BPE requires no shortlist.
5. The final symbol vocabulary size is equal to the size of the initial vocabulary plus the number of merge operations — the latter is the only hyperparameter of the algorithm.

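For efficiency, the original algorithm does not consider pairs that cross word boundaries. Note that the simplified implementation below counts pairs over the whole token stream, so a learned merge may include the word-split marker `Ġ`.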

In [4]:
class BytePairEncodingTokenizer:
  """Character-level Byte Pair Encoding (BPE) tokenizer.

  The base vocabulary consists of the 256 characters chr(0)..chr(255), the
  word-split marker and any special tokens. Training repeatedly merges the
  most frequent pair of adjacent token IDs into a new token ID.

  Attributes:
    vocab: dict mapping token ID -> token string.
    inverse_vocab: dict mapping token string -> token ID.
    tokens: token IDs of the training text after all merges were applied.
    bpe_pairs: dict mapping a merged pair (id_a, id_b) -> new token ID,
      in the order the merges were learned.
    word_split: marker that stands in for the space preceding a word.
    special_tokens: special tokens passed to the last `train` call
      (None before training).
  """

  def __init__(self, word_split='Ġ'):
    self.vocab = {}
    self.inverse_vocab = {}
    self.tokens = []
    self.bpe_pairs = {}
    self.word_split = word_split
    self.special_tokens = None

  def set_vocabulary(self, word_split, special_tokens):
    """Return the base vocabulary as a list of token strings.

    Args:
      word_split: word-boundary marker; appended if it is not already one of
        the 256 base characters.
      special_tokens: iterable of extra tokens appended at the end, or a
        falsy value for none.

    Returns:
      List whose position is the token ID of each entry.
    """
    # Step 1: the 256 base characters
    unique_characters = [chr(code) for code in range(256)]

    # Step 2: word-split marker (use the argument, not instance state, so the
    # method also works before `train` has been called)
    if word_split not in unique_characters:
      unique_characters.append(word_split)

    # Step 3: special tokens. A special token that duplicates an existing
    # entry (e.g. "@") still receives its own, later ID.
    if special_tokens:
      unique_characters.extend(special_tokens)

    return unique_characters

  def get_max_freq_pair(self, tokens):
    """Return the most frequent pair of adjacent tokens.

    Ties are resolved in favour of the pair that occurs first in `tokens`.

    Returns:
      The pair as a tuple, or None when `tokens` has fewer than two elements
      (so no pair exists).
    """
    pair_counts = Counter(zip(tokens, tokens[1:]))
    if not pair_counts:
      return None
    # max() keeps the first maximal key; Counter preserves insertion order.
    return max(pair_counts, key=pair_counts.get)

  def merge_tokens(self, tokens, max_pair, new_pair_id):
    """Replace every non-overlapping occurrence of `max_pair` by `new_pair_id`.

    Eg: tokens=[87,76,44,25,38,44,25,19], max_pair=(44,25), new_pair_id=123
    Output: [87,76,123,38,123,19]

    `max_pair` may be any two-element sequence. The input list is not modified.
    """
    max_pair = tuple(max_pair)  # a list would never compare equal to a tuple
    last_index = len(tokens) - 1
    new_tokens = []
    i = 0
    while i <= last_index:
      if i < last_index and (tokens[i], tokens[i + 1]) == max_pair:
        new_tokens.append(new_pair_id)
        i += 2
      else:
        new_tokens.append(tokens[i])
        i += 1
    return new_tokens

  def train(self, text, vocab_size, special_tokens):
    """Learn BPE merges from `text` until the vocabulary has `vocab_size` entries.

    Any previously learned vocabulary and merges are discarded. Training
    stops early when no adjacent pair is left to merge.

    Args:
      text: training corpus. Spaces are replaced by the word-split marker
        (a space at index 0 is dropped).
      vocab_size: target vocabulary size; must be greater than 258.
      special_tokens: iterable of special tokens, or None.

    Raises:
      ValueError: if `vocab_size` is not greater than 258.
      KeyError: if `text` contains a character outside the base vocabulary.
    """
    if vocab_size <= 258:
      raise ValueError('Please enter a vocab size greater than 258 since this defines the basic set of characters')
    self.special_tokens = special_tokens

    # Start from a clean state so that retraining does not keep stale merges.
    self.vocab.clear()
    self.inverse_vocab.clear()
    self.bpe_pairs.clear()

    # Setting the base vocabulary
    base_vocab = self.set_vocabulary(self.word_split, self.special_tokens)
    for index, character in enumerate(base_vocab):
      self.vocab[index] = character
      self.inverse_vocab[character] = index

    # Transforming the text
    ## Step 1: replace every space (except one at index 0) with the marker
    symbols = []
    for index, char in enumerate(text):
      if char == ' ':
        if index != 0:
          symbols.append(self.word_split)
      else:
        symbols.append(char)

    ## Step 2: numerical form of the symbols
    self.tokens = [self.inverse_vocab[symbol] for symbol in symbols]

    ## Step 3: BPE algorithm - one new token ID per merge
    for new_id in range(len(self.vocab), vocab_size):
      max_pair = self.get_max_freq_pair(self.tokens)
      if max_pair is None:
        break
      self.bpe_pairs[max_pair] = new_id
      self.tokens = self.merge_tokens(self.tokens, max_pair, new_id)

      ## Step 4: register the merged token. Both halves are already in the
      ## vocabulary because merges are registered in the order learned.
      merged_token = self.vocab[max_pair[0]] + self.vocab[max_pair[1]]
      self.vocab[new_id] = merged_token
      self.inverse_vocab[merged_token] = new_id

  def encode(self, text):
    """Convert `text` into a list of token IDs.

    The text is split on whitespace, so runs of whitespace collapse into a
    single word boundary. Every word except the first is prefixed with the
    word-split marker; this includes a word that directly follows a newline.

    Raises:
      ValueError: if a word contains a character that has no token ID.
    """
    # Step 1: split into words. Newlines are temporarily replaced by
    # " <NEWLINE> " so that they survive the whitespace split as own tokens.
    words = []
    for piece in text.replace('\n', ' <NEWLINE> ').split():
      words.append('\n' if piece == '<NEWLINE>' else piece)

    token_ids = []
    for index, word in enumerate(words):
      # Step 2: mark word boundaries
      ## Eg: 'This is a ball' becomes ['This', 'Ġis', 'Ġa', 'Ġball']
      if index > 0 and not word.startswith('\n'):
        word = self.word_split + word

      # Step 3: use the whole word if it is a known token, otherwise fall
      # back to BPE tokenization of its characters
      if word in self.inverse_vocab:
        token_ids.append(self.inverse_vocab[word])
      else:
        token_ids.extend(self.tokenize_using_bpe(word))
    return token_ids

  def tokenize_using_bpe(self, token):
    """Tokenize a single word by applying the learned merges to its characters.

    Each pass scans left to right and merges every adjacent pair found in
    `bpe_pairs`; passes repeat until one of them performs no merge.

    Raises:
      ValueError: if a character of `token` has no token ID; the message
        lists every such character.
    """
    # Step 1: map the characters to their IDs (None when unknown)
    token_ids = [self.inverse_vocab.get(char) for char in token]

    # Step 2: stop if any character is missing from the vocabulary
    if None in token_ids:
      missing_characters = list(dict.fromkeys(
          char for char, token_id in zip(token, token_ids) if token_id is None))
      raise ValueError(f"No token IDs found for the characters:{missing_characters}")

    # Step 3: merging
    can_merge = True
    while can_merge and len(token_ids) > 1:
      can_merge = False
      i = 0
      new_tokens = []
      # If the pair at position i was learned during training, emit its
      # merged ID and skip both tokens; otherwise keep token i and advance.
      while i < len(token_ids) - 1:
        pair = (token_ids[i], token_ids[i + 1])
        if pair in self.bpe_pairs:
          new_tokens.append(self.bpe_pairs[pair])
          i += 2
          can_merge = True
        else:
          new_tokens.append(token_ids[i])
          i += 1
      # The last token is left over when it was not consumed by a merge.
      if i < len(token_ids):
        new_tokens.append(token_ids[i])
      token_ids = new_tokens

    return token_ids

  def decode(self, token_ids):
    """Convert token IDs back into text.

    Every occurrence of the word-split marker is turned into a space, also
    when a merged token carries the marker in the middle or at the end.

    Raises:
      ValueError: if any ID is not part of the vocabulary.
    """
    token_ids = list(token_ids)

    # Step 1: check for unknown token IDs
    non_existing_ids = [token_id for token_id in token_ids if token_id not in self.vocab]
    if non_existing_ids:
      raise ValueError(f"No token found for the token IDs:{non_existing_ids}")

    # Step 2: concatenate the token strings, marker -> " "
    return "".join(
        self.vocab[token_id].replace(self.word_split, " ") for token_id in token_ids)

  def save_bpe_vocab_and_merges(self, vocab_path, bpe_path):
    """Write the vocabulary and the learned merges to two JSON files."""
    with open(vocab_path, 'w', encoding='utf-8') as f:
      json.dump(self.vocab, f, ensure_ascii=False, indent=2)
    merges = [{'pair': list(pair), 'id': pair_id} for pair, pair_id in self.bpe_pairs.items()]
    with open(bpe_path, 'w', encoding='utf-8') as f:
      json.dump(merges, f, ensure_ascii=False, indent=2)

  def load_bpe_vocab_and_merges(self, vocab_path, bpe_path):
    """Replace the vocabulary and merges with those stored in two JSON files."""
    with open(vocab_path, 'r', encoding='utf-8') as f:
      loaded_vocab = json.load(f)
    # JSON object keys are strings; restore the integer token IDs.
    self.vocab = {int(token_id): token for token_id, token in loaded_vocab.items()}
    self.inverse_vocab = {token: token_id for token_id, token in self.vocab.items()}
    with open(bpe_path, 'r', encoding='utf-8') as f:
      merges = json.load(f)
    # Rebuild from scratch so merges from an earlier state do not linger.
    self.bpe_pairs = {tuple(merge['pair']): merge['id'] for merge in merges}



In [5]:
# Build the tokenizer and train it on the downloaded text.
# NOTE(review): the base vocabulary already holds 259 entries (256 characters,
# the 'Ġ' marker and the two special tokens), so vocab_size=259 leaves no room
# for BPE merges and the tokenizer stays character-level. "@" is also already
# one of the 256 base characters. Use a larger vocab_size to learn merges.
tokenizer = BytePairEncodingTokenizer()
tokenizer.train(text, vocab_size=259, special_tokens=("<|endoftext|>", "@"))

In [6]:
# Display the token ID -> token string mapping
tokenizer.vocab

{0: '\x00',
 1: '\x01',
 2: '\x02',
 3: '\x03',
 4: '\x04',
 5: '\x05',
 6: '\x06',
 7: '\x07',
 8: '\x08',
 9: '\t',
 10: '\n',
 11: '\x0b',
 12: '\x0c',
 13: '\r',
 14: '\x0e',
 15: '\x0f',
 16: '\x10',
 17: '\x11',
 18: '\x12',
 19: '\x13',
 20: '\x14',
 21: '\x15',
 22: '\x16',
 23: '\x17',
 24: '\x18',
 25: '\x19',
 26: '\x1a',
 27: '\x1b',
 28: '\x1c',
 29: '\x1d',
 30: '\x1e',
 31: '\x1f',
 32: ' ',
 33: '!',
 34: '"',
 35: '#',
 36: '$',
 37: '%',
 38: '&',
 39: "'",
 40: '(',
 41: ')',
 42: '*',
 43: '+',
 44: ',',
 45: '-',
 46: '.',
 47: '/',
 48: '0',
 49: '1',
 50: '2',
 51: '3',
 52: '4',
 53: '5',
 54: '6',
 55: '7',
 56: '8',
 57: '9',
 58: ':',
 59: ';',
 60: '<',
 61: '=',
 62: '>',
 63: '?',
 64: '@',
 65: 'A',
 66: 'B',
 67: 'C',
 68: 'D',
 69: 'E',
 70: 'F',
 71: 'G',
 72: 'H',
 73: 'I',
 74: 'J',
 75: 'K',
 76: 'L',
 77: 'M',
 78: 'N',
 79: 'O',
 80: 'P',
 81: 'Q',
 82: 'R',
 83: 'S',
 84: 'T',
 85: 'U',
 86: 'V',
 87: 'W',
 88: 'X',
 89: 'Y',
 90: 'Z',
 91: '[',


In [7]:
# Test the tokenizer on a sample text: encode it into a list of token IDs
# (ID 256 in the output below is the 'Ġ' word-split marker)
sample = "Who are you?"
token_ids = tokenizer.encode(sample)
token_ids

[87, 104, 111, 256, 97, 114, 101, 256, 121, 111, 117, 63]

In [8]:
# Decode token_ids back to text; the round trip should reproduce `sample`
tokenizer.decode(token_ids)

'Who are you?'

## Embedding Layer

We use learned embeddings to convert the input tokens and output tokens to vectors of dimension $d_{model}$. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation. In the embedding layers, we multiply those weights by $\sqrt{d_{model}}$.

In [9]:
# Convert the list of token IDs into a LongTensor, to be used as indices
# into the embedding table
token_long_tensor = torch.LongTensor(token_ids)
token_long_tensor

tensor([ 87, 104, 111, 256,  97, 114, 101, 256, 121, 111, 117,  63])

In [10]:
class Embedding(nn.Module):
  """Learned token embedding whose output is scaled by sqrt(embedding_dim).

  Args:
    num_embeddings: number of rows in the embedding table (one per token ID).
    embedding_dim: size of each embedding vector.
  """

  def __init__(self, num_embeddings, embedding_dim):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.embedding = nn.Embedding(num_embeddings=num_embeddings,
                                  embedding_dim=embedding_dim)

  def forward(self, x):
    """Return the embeddings of the token IDs in `x`, times sqrt(embedding_dim).

    The output has the shape of `x` with a trailing axis of size
    `embedding_dim` appended.
    """
    # math.sqrt gives a plain Python float, so the scaling is an ordinary
    # tensor-scalar product rather than a NumPy-scalar/tensor mix.
    return math.sqrt(self.embedding_dim) * self.embedding(x)

In [11]:
# One embedding vector of dimension 8 per vocabulary entry
embedding = Embedding(num_embeddings=len(tokenizer.vocab), embedding_dim=8)

In [12]:
# Look up the (scaled) embedding of every token ID in the sample sequence
embedded_tokens = embedding(token_long_tensor)
embedded_tokens

tensor([[ 8.5416e-01, -1.7670e+00,  2.1068e+00, -4.1173e+00,  7.5981e-01,
          1.1021e+00, -6.3376e-01,  3.3582e+00],
        [-2.2502e+00, -1.2011e+00,  6.6511e+00, -4.3043e-01, -2.1113e+00,
          1.7882e+00,  3.8869e-01, -4.7831e+00],
        [ 4.4844e+00,  5.2736e+00, -1.9191e+00,  1.1307e+00, -1.1659e+00,
          5.6371e+00,  3.6841e+00, -2.3914e+00],
        [-1.1385e+00, -1.6680e+00,  4.2019e+00,  3.7996e+00,  5.1161e+00,
          1.2670e+00,  1.3824e+00,  1.1164e+00],
        [ 9.2625e-01,  2.0720e-01,  6.7946e+00, -5.2076e+00, -4.7721e+00,
          5.5989e+00, -2.8884e+00,  8.4984e-01],
        [ 2.8368e+00,  7.9613e-01,  1.7057e+00,  4.0691e-01, -1.4522e+00,
          4.5691e-01,  7.3491e-01, -2.1884e+00],
        [-8.0211e-01,  2.3011e+00,  3.7422e+00,  4.1184e+00,  2.2650e+00,
         -2.6388e+00, -2.1120e+00, -3.6979e+00],
        [-1.1385e+00, -1.6680e+00,  4.2019e+00,  3.7996e+00,  5.1161e+00,
          1.2670e+00,  1.3824e+00,  1.1164e+00],
        [-7.0914

In [13]:
# Shape is (number of tokens, embedding_dim)
embedded_tokens.shape

torch.Size([12, 8])

In [14]:
# Position 2 holds token ID 111; compare with position 9, which holds the same ID
print(f'Token {token_ids[2]}, Embedding {embedded_tokens[2]}')

Token 111, Embedding tensor([ 4.4844,  5.2736, -1.9191,  1.1307, -1.1659,  5.6371,  3.6841, -2.3914],
       grad_fn=<SelectBackward0>)


In [15]:
# Position 9 holds token ID 111 as well, so the embedding lookup returns the
# same vector as for position 2: the embedding alone carries no position information
print(f'Token {token_ids[9]}, Embedding {embedded_tokens[9]}')

Token 111, Embedding tensor([ 4.4844,  5.2736, -1.9191,  1.1307, -1.1659,  5.6371,  3.6841, -2.3914],
       grad_fn=<SelectBackward0>)


## Positional Encoding

We use positional encoding to make the model aware of the order of the sequence, because the attention mechanism itself has no notion of position. To solve this, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $d_{model}$ as the embeddings, so that the two can be summed.

We use sin and cosine functions of different frequencies
- $PE(pos, 2i) = \sin(pos / 10000^{2i/d_{model}})$
- $PE(pos, 2i+1) = \cos(pos / 10000^{2i/d_{model}})$

Each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from $2\pi$ to $10000 \cdot 2\pi$.

In [16]:
class PositionalEncoding(nn.Module):
    """
    Standard Sinusoidal Positional Encoding.

    Even columns hold sin(pos * f_i) and odd columns cos(pos * f_i), with
    f_i = wavelength ** (-2i / embedding_dim).

    wavelength: factor to determine the wavelength in the sinusoidal function.
    """
    def __init__(self, wavelength=10000.):
        super().__init__()
        self.wavelength = wavelength

    def forward(self, x):
        """Given a (... x seq_len x embedding_dim) tensor, returns a (seq_len x embedding_dim) tensor.

        Only the shape and device of `x` are used; its values are ignored.
        The result lives on the same device as `x`, so the two can be summed.
        """
        seq_len, embedding_dim = x.shape[-2], x.shape[-1]
        device = x.device
        pe = torch.zeros((seq_len, embedding_dim), device=device)
        # Column vector of positions 0..seq_len-1, broadcast against the frequencies.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        # One frequency per sin/cos column pair: ceil(embedding_dim / 2) entries.
        factor = torch.exp(
            -math.log(self.wavelength)
            * torch.arange(0, embedding_dim, 2, device=device)
            / embedding_dim
        )
        angles = position * factor
        pe[:, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one cosine column fewer than sine
        # columns, so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        return pe

In [21]:
# Positional encodings for the embedded sequence (one row per position)
pe = PositionalEncoding()
pe(embedded_tokens)

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
          9.9955e-01,  3.0000e-03,  1.0000e+00],
        [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
          9.9920e-01,  4.0000e-03,  9.9999e-01],
        [-9.5892e-01,  2.8366e-01,  4.7943e-01,  8.7758e-01,  4.9979e-02,
          9.9875e-01,  5.0000e-03,  9.9999e-01],
        [-2.7942e-01,  9.6017e-01,  5.6464e-01,  8.2534e-01,  5.9964e-02,
          9.9820e-01,  6.0000e-03,  9.9998e-01],
        [ 6.5699e-01,  7.5390e-01,  6.4422e-01,  7.6484e-01,  6.9943e-02,
          9.9755e-01,  6.9999e-03,  9.9998e-01],
        [ 9.8936

In [23]:
# Add the positional encodings element-wise to the token embeddings
encoded_embeddings=pe(embedded_tokens)+embedded_tokens
encoded_embeddings

tensor([[ 8.5416e-01, -7.6701e-01,  2.1068e+00, -3.1173e+00,  7.5981e-01,
          2.1021e+00, -6.3376e-01,  4.3582e+00],
        [-1.4087e+00, -6.6075e-01,  6.7509e+00,  5.6457e-01, -2.1013e+00,
          2.7882e+00,  3.8969e-01, -3.7831e+00],
        [ 5.3937e+00,  4.8574e+00, -1.7205e+00,  2.1107e+00, -1.1459e+00,
          6.6369e+00,  3.6861e+00, -1.3914e+00],
        [-9.9740e-01, -2.6580e+00,  4.4974e+00,  4.7549e+00,  5.1461e+00,
          2.2666e+00,  1.3854e+00,  2.1164e+00],
        [ 1.6945e-01, -4.4645e-01,  7.1840e+00, -4.2865e+00, -4.7321e+00,
          6.5981e+00, -2.8844e+00,  1.8498e+00],
        [ 1.8779e+00,  1.0798e+00,  2.1851e+00,  1.2845e+00, -1.4022e+00,
          1.4557e+00,  7.3991e-01, -1.1885e+00],
        [-1.0815e+00,  3.2613e+00,  4.3068e+00,  4.9437e+00,  2.3250e+00,
         -1.6406e+00, -2.1060e+00, -2.6979e+00],
        [-4.8153e-01, -9.1414e-01,  4.8461e+00,  4.5644e+00,  5.1860e+00,
          2.2646e+00,  1.3894e+00,  2.1164e+00],
        [ 2.8022