# Vanilla Transformer with PyTorch

This notebook tests a from-scratch PyTorch implementation of the vanilla Transformer model. The model code lives under the `models` directory.

In [1]:
import sys
from path_utils import add_parent_path_to_sys_path
# Add the parent directory to sys.path so that the `models` package can be
# imported. This has to run before the `models` import further down.
# NOTE(review): sys.path[0] is presumably the notebook's own directory — confirm.
current_path = sys.path[0]
# NOTE(review): the recorded cell output shows "Path added to the sys path."
# even though verbose=False is passed — confirm what `verbose` controls.
add_parent_path_to_sys_path(current_path, verbose=False)

# Import the model under test (relies on the sys.path update above).
from models.vanilla_transformers import Transformer

import torch

Path added to the sys path.


In [2]:
# setting up the model
# Fixed seed for PyTorch's global RNG. It is set before the model is built so
# that weight initialisation and every later random draw in this notebook
# (fake data, dropout) repeat on Restart & Run All with the same setup.
seed = 42

# Hyperparameters (the values match the printed module summary below:
# 512-wide embeddings/projections, 2048-wide feed-forward, 6 layers, p=0.1).
source_vocab_size = 4096  # rows of the encoder embedding table
target_vocab_size = 4096  # rows of the decoder embedding table
d_model = 512             # embedding / attention projection width
num_heads = 8             # attention heads per MultiHeadAttention
num_layers = 6            # number of stacked encoder (and decoder) layers
d_ff = 2048               # hidden width of the position-wise feed-forward
max_seq_length = 128      # sequence length used for the fake data
dropout = 0.1             # dropout probability

torch.manual_seed(seed)

# Arguments are passed positionally, in the order the constructor is called
# with in the original notebook (the Transformer signature is not visible here).
transformer = Transformer(
    source_vocab_size,
    target_vocab_size,
    d_model,
    num_heads,
    num_layers,
    d_ff,
    max_seq_length,
    dropout,
)
print(transformer)

Transformer(
  (encoder_embedding): Embedding(4096, 512)
  (decoder_embedding): Embedding(4096, 512)
  (positional_encoding): PositionalEncoding()
  (transformer_encoder): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=True)
        (W_k): Linear(in_features=512, out_features=512, bias=True)
        (W_v): Linear(in_features=512, out_features=512, bias=True)
        (W_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feedforward): PositionWiseFeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (activation): ReLU()
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (transformer_decoder): ModuleList(
    

In [3]:
# Build one random (source, target) pair of token-id sequences.
# Both tensors have shape (batch_size=1, max_seq_length). Ids are drawn from
# [1, vocab_size), so id 0 never appears in the fake data.
source_data = torch.randint(low=1, high=source_vocab_size, size=(1, max_seq_length))
target_data = torch.randint(low=1, high=target_vocab_size, size=(1, max_seq_length))

In [4]:
# setting up the loss function
# ignore_index=0: targets equal to 0 are excluded from the loss. The fake data
# in this notebook is drawn with low=1, so no position is actually ignored.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
# setting up the optimizer
optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=0.0001,            # learning rate
    betas=(0.9, 0.98),    # running-average coefficients
    eps=1e-9,             # numerical-stability term
    weight_decay=0.0001,  # L2 penalty
)

# Number of optimisation steps for this smoke test. Every "epoch" reuses the
# same single fake batch.
num_epochs = 10

# Next-token setup: the model receives the target without its last token and
# is scored against the target shifted left by one. Both slices are constant
# across iterations, so they are computed once outside the loop.
decoder_input = target_data[:, :-1]             # (batch_size, max_seq_length - 1)
target_labels = target_data[:, 1:].reshape(-1)  # (batch_size * (max_seq_length - 1),)

# training loop
transformer.train()  # training mode, so dropout is active
for epoch in range(1, num_epochs + 1):
    # zero out the gradients accumulated by the previous step
    optimizer.zero_grad()
    # forward pass
    # NOTE(review): the second argument is presumably the decoder input;
    # the Transformer.forward signature is not visible here — confirm.
    output = transformer(source_data, decoder_input)
    # calculate the loss: flatten the scores to
    # (batch_size * (max_seq_length - 1), target_vocab_size) to line up with
    # the flattened labels
    loss = criterion(output.reshape(-1, target_vocab_size), target_labels)
    # backward pass
    loss.backward()
    # update the weights
    optimizer.step()
    # print the loss
    print(f'Epoch: {epoch}, Loss: {loss.item()}')

Epoch: 1, Loss: 8.436713218688965
Epoch: 2, Loss: 7.639357566833496
Epoch: 3, Loss: 7.231591701507568
Epoch: 4, Loss: 7.009538173675537
Epoch: 5, Loss: 6.817314624786377
Epoch: 6, Loss: 6.585111618041992
Epoch: 7, Loss: 6.248195171356201
Epoch: 8, Loss: 5.939205169677734
Epoch: 9, Loss: 5.512218475341797
Epoch: 10, Loss: 5.180314064025879
