This project investigates integrating a recurrent connection into a pre-trained transformer model (DistilGPT-2) using PyTorch. The main goal is to explore whether adding a recurrent connection to a purely attention-based model justifies the cost of additional parameters, longer training times, and other architectural constraints.
Before running the project, ensure the following dependencies are installed:
- torch >= 1.10
- tqdm
- numpy
- wandb
You can install these dependencies by using the following command:
pip install "torch>=1.10" tqdm numpy wandb
The project is structured as follows:
This section involves loading and preprocessing the data before feeding it to the model.
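As a sketch of the preprocessing step (the function name and the toy token stream here are illustrative, not the project's actual code), a long stream of token ids can be chunked into fixed-length blocks, with targets shifted by one position as is standard for autoregressive language modeling:

```python
import torch

def make_blocks(token_ids, block_size):
    """Split a stream of token ids into fixed-length training blocks.

    Inputs are the block itself; targets are the block shifted by one token.
    """
    n_blocks = (len(token_ids) - 1) // block_size
    ids = torch.tensor(token_ids[: n_blocks * block_size + 1])
    inputs = ids[:-1].view(n_blocks, block_size)
    targets = ids[1:].view(n_blocks, block_size)
    return inputs, targets

# toy stream of 21 token ids, block size 5 -> 4 (input, target) blocks
x, y = make_blocks(list(range(21)), block_size=5)
print(x.shape, y.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```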
The initial step is to build a basic model as a starting point for further enhancements. In this step, basic and multi-head self-attention layers are added on top of the classifier model.
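A minimal sketch of the multi-head self-attention layer added on top of the classifier, using PyTorch's built-in `nn.MultiheadAttention` (the embedding size and head count are illustrative):

```python
import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    """Multi-head self-attention over a batch of token embeddings."""

    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # self-attention: queries, keys, and values are all x
        out, _ = self.attn(x, x, x)
        return out

block = SelfAttentionBlock()
x = torch.randn(2, 10, 64)   # (batch, seq_len, embed_dim)
out = block(x)
print(out.shape)             # torch.Size([2, 10, 64])
```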
Next, a complete transformer is added on top of the classifier to introduce a more complex model.
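The full transformer stacked on the classifier can be sketched with PyTorch's standard encoder layers; this stands in for the project's own implementation, and all sizes below are illustrative:

```python
import torch
import torch.nn as nn

# a small stack of standard transformer layers (attention + feed-forward,
# with residual connections and layer norm inside each layer)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=64, nhead=4, dim_feedforward=128, batch_first=True
    ),
    num_layers=2,
)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
h = encoder(x)
print(h.shape)               # torch.Size([2, 10, 64])
```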
A generator transformer is implemented in this section, enhancing the model's capabilities.
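The key ingredient of a generator (decoder-style) transformer is a causal mask, so each position attends only to earlier positions. A small sketch of such a mask, following the `attn_mask` convention of `nn.MultiheadAttention` (True marks positions that must not be attended to):

```python
import torch

def causal_mask(seq_len):
    """Boolean mask where True marks future positions that may not be
    attended to, per nn.MultiheadAttention's attn_mask convention."""
    return torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
    )

m = causal_mask(4)
print(m)  # upper triangle (strictly above the diagonal) is True
```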
The main focus of this project, a recurrent connection, is integrated into the generator transformer.
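The recurrent connection can be sketched as below. This is one plausible design (a GRU run over the transformer's hidden states, merged back with a residual connection), not necessarily the exact wiring used in the project, and all sizes are illustrative:

```python
import torch
import torch.nn as nn

class RecurrentTransformerHead(nn.Module):
    """A transformer layer followed by a recurrent (GRU) pass over its
    hidden states, merged back via a residual connection."""

    def __init__(self, d_model=64, nhead=4):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, x):
        h = self.block(x)      # attention-based hidden states
        r, _ = self.rnn(h)     # recurrent pass over the sequence
        return h + r           # residual merge of attention and recurrence

model = RecurrentTransformerHead()
x = torch.randn(2, 10, 64)
out = model(x)
print(out.shape)  # torch.Size([2, 10, 64])
```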
Finally, the results and metrics (loss, gradient-clipping norm, and perplexity) are tracked with Weights & Biases (wandb) for evaluation.
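The tracked metrics can be computed as sketched below; the model is a toy stand-in, and the `wandb.log` call is shown as a comment because it requires an active run:

```python
import math
import torch
import torch.nn as nn

# toy model and loss standing in for the actual language model
model = nn.Linear(8, 8)
logits = model(torch.randn(4, 8))
loss = nn.functional.cross_entropy(logits, torch.randint(0, 8, (4,)))
loss.backward()

# clip_grad_norm_ returns the total gradient norm before clipping
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# perplexity is the exponential of the cross-entropy loss
perplexity = math.exp(loss.item())

# wandb.log({"loss": loss.item(), "grad_norm": grad_norm.item(),
#            "perplexity": perplexity})  # inside an active wandb run
print(loss.item(), grad_norm.item(), perplexity)
```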
This project aims to understand the impact of a recurrent connection in a pre-trained transformer model. By comparing the performance, training time, and architectural constraints, we can determine whether the addition of recurrent connections is beneficial in the context of attention-based models.