To get intimately familiar with the nuts and bolts of transformers, I decided to implement the original architecture from *Attention Is All You Need* (Vaswani et al., 2017).
This repo accompanies the blog post *Implementing a Transformer From Scratch: 7 surprising things you might not know about the Transformer*, which I wrote to highlight the things I learned in the process that I found particularly surprising or insightful.
Each Python file contains one or more classes related to the transformer, with unit tests for those classes at the bottom of the file. The unit tests are executed simply by running the file (e.g. `python transformer.py`) and run on every push to this repo using GitHub Actions. They serve two purposes: first, as sanity checks that verify whether each class does what it should; second, as examples of how to use each class.
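For illustration, here is a minimal sketch of that pattern; the module and test below are hypothetical stand-ins, not classes from this repo:

```python
import unittest

import torch


class ScaleByTwo(torch.nn.Module):
    """Toy module standing in for one of the repo's classes (hypothetical)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


class TestScaleByTwo(unittest.TestCase):
    def test_doubles_input(self):
        x = torch.ones(2, 3)
        self.assertTrue(torch.equal(ScaleByTwo()(x), x * 2))


if __name__ == "__main__":
    # Running `python <file>.py` executes the tests defined in that file.
    unittest.main()
```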
In practice, of course, please do use the official PyTorch implementation. This repo is by no means meant as an alternative: it is meant to help me (and hopefully you) better understand how transformers are actually implemented.
This repo contains the following files and features:
- The simplest imaginable vocabulary (`vocabulary.py`).
- The simplest imaginable (batch) tokenizer (`vocabulary.py`); a sketch of both follows this list.
- `TransformerEncoder` and `EncoderBlock` classes (`encoder.py`); a minimal encoder block is sketched below.
- `TransformerDecoder` and `DecoderBlock` classes (`decoder.py`).
- `Transformer` main class (`transformer.py`).
- Train script with a unit test that (over)fits a synthetic copy dataset (`train.py`); see the copy-batch sketch below.
- `MultiHeadAttention` class with scaled dot-product attention and masking (`multi_head_attention.py`); the core computation is sketched below.
- `SinusoidEncoding` class for positional encodings (`positional_encodings.py`); sketched below.
- Utility functions to construct masks and batches (`utils.py`); mask construction is sketched below.
- Learning rate scheduler (`lr_scheduler.py`); the paper's schedule is sketched below.
- Basic unit tests for each class; running a file (e.g. `python encoder.py`) executes its unit tests.
- Type checking using mypy.
- Code formatting using black.
- Automatic execution of unit tests and type checks using GitHub Actions.
- No dependencies except Python 3.9 and PyTorch 1.9.1 (though basically any version should work).
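To make the first two items concrete, here is a minimal sketch of a vocabulary with batch tokenization. The class below is hypothetical; the repo's actual `vocabulary.py` may differ in naming and details:

```python
class Vocabulary:
    """Hypothetical sketch: whitespace tokens mapped to integer indices."""

    def __init__(self, sentences: list[str]) -> None:
        # Reserve special tokens up front.
        self.token2index = {"<pad>": 0, "<bos>": 1, "<eos>": 2}
        for sentence in sentences:
            for token in sentence.split():
                self.token2index.setdefault(token, len(self.token2index))

    def batch_encode(self, sentences: list[str]) -> list[list[int]]:
        # Whitespace splitting: the simplest imaginable tokenizer.
        batch = [
            [self.token2index["<bos>"]]
            + [self.token2index[token] for token in sentence.split()]
            + [self.token2index["<eos>"]]
            for sentence in sentences
        ]
        # Pad every sequence to the length of the longest one in the batch.
        max_len = max(len(sequence) for sequence in batch)
        pad = self.token2index["<pad>"]
        return [seq + [pad] * (max_len - len(seq)) for seq in batch]


vocab = Vocabulary(["a cat sat", "a dog"])
print(vocab.batch_encode(["a cat sat", "a dog"]))  # two padded index sequences
```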
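For a feel of the architecture behind `encoder.py`, here is a hypothetical sketch of a single post-norm encoder block: self-attention plus a position-wise feed-forward network, each wrapped in a residual connection and layer normalization, as in the paper. For brevity it uses PyTorch's built-in `nn.MultiheadAttention`, whereas the repo implements its own:

```python
import torch
from torch import nn


class EncoderBlockSketch(nn.Module):
    """Hypothetical post-norm encoder block, following the original paper."""

    def __init__(self, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048, p: float = 0.1):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sub-layer 1: multi-head self-attention + residual + layer norm.
        attended, _ = self.self_attention(x, x, x)
        x = self.norm1(x + self.dropout(attended))
        # Sub-layer 2: position-wise feed-forward + residual + layer norm.
        return self.norm2(x + self.dropout(self.feed_forward(x)))
```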
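The copy task that `train.py` (over)fits can be pictured with a batch generator like the following; this is a hypothetical sketch, not the repo's actual data code:

```python
import torch


def random_copy_batch(
    batch_size: int, seq_len: int, vocab_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Source and target are identical, so the model only has to learn to copy."""
    # Indices 0-2 are assumed reserved for special tokens like <pad>/<bos>/<eos>.
    tokens = torch.randint(low=3, high=vocab_size, size=(batch_size, seq_len))
    return tokens, tokens.clone()
```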
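The heart of `multi_head_attention.py` is scaled dot-product attention from the paper, `Attention(Q, K, V) = softmax(QKᵀ / √d_k)V`. A sketch, where the function name and the mask convention (boolean, `True` = may attend) are assumptions:

```python
import math
from typing import Optional

import torch


def scaled_dot_product_attention(
    q: torch.Tensor,  # (batch, n_heads, seq_len, d_head)
    k: torch.Tensor,
    v: torch.Tensor,
    mask: Optional[torch.Tensor] = None,  # boolean, True = may attend
) -> torch.Tensor:
    # Scale by sqrt(d_head) to keep the dot products in a reasonable range.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        # Masked positions get -inf so softmax assigns them (near-)zero weight.
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```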
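`SinusoidEncoding` is based on the paper's formulas `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`. A sketch of how the full table can be built (the function name is hypothetical; an even `d_model` is assumed):

```python
import math

import torch


def sinusoid_encoding_table(max_len: int, d_model: int) -> torch.Tensor:
    """Returns a (max_len, d_model) table of sinusoidal positional encodings."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i / d_model), computed in log space for numerical stability.
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
    )
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    table[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return table
```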
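Two masks that transformer utilities typically construct are a padding mask (ignore `<pad>` tokens) and a causal mask for the decoder (don't attend to future positions). A sketch with assumed names and conventions:

```python
import torch


def padding_mask(token_ids: torch.Tensor, pad_index: int = 0) -> torch.Tensor:
    # (batch, 1, seq_len): True where the token is real, False where it is padding.
    return (token_ids != pad_index).unsqueeze(-2)


def causal_mask(seq_len: int) -> torch.Tensor:
    # (seq_len, seq_len) lower-triangular mask: position i may attend to j <= i.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
```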
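The paper's schedule increases the learning rate linearly for the first `warmup_steps` training steps and then decays it proportionally to the inverse square root of the step number: `lrate = d_model^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5))`. A sketch of that formula as a helper (whether `lr_scheduler.py` matches it exactly is an assumption):

```python
def lr_for_step(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate schedule from Attention Is All You Need (hypothetical helper)."""
    step = max(step, 1)  # guard against step = 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```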