<a href="https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/Reformer_3_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **The Reformer - Pushing the limits of language modeling**

***How the Reformer uses less than 8GB of RAM to train on sequences of half a million tokens***

The Reformer model as introduced by [Kitaev, Kaiser et al. (2020)](https://arxiv.org/pdf/2001.04451.pdf) is one of the most memory-efficient transformer models for long sequence modeling as of today.

Recently, long sequence modeling has experienced a surge of interest as can be seen by the many submissions from this year alone - [Beltagy et al. (2020)](https://arxiv.org/abs/2004.05150), [Roy et al. (2020)](https://arxiv.org/abs/2003.05997), [Tay et al.](https://arxiv.org/abs/2002.11296), [Wang et al.](https://arxiv.org/abs/2006.04768) to name  a few. 
The motivation behind long sequence modeling is that many tasks in NLP, *e.g.* summarization, question answering, require the model to process longer input sequences than models, such as BERT, are able to handle. In tasks that require the model to process a large input sequence, long sequence models do not have to cut the input sequence to avoid memory overflow and thus have been shown to outperform standard "BERT"-like models *cf.* [Beltagy et al. (2020)](https://arxiv.org/abs/2004.05150). 

The Reformer pushes the limit of long sequence modeling by its ability to process up to half a million tokens at once as shown in this [demo](https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb). As a comparison, a conventional `bert-base-uncased` model limits the input length to only 512 tokens. In Reformer, each part of the standard transformer architecture is re-engineered to minimize memory requirements without a significant drop in performance.

The memory improvements can be attributed to **4** features which the Reformer authors introduced to the transformer world:

1.   **Reformer Self-Attention Layer** - *How to efficiently implement self-attention without being restricted to a local context?* => see [this colab](https://colab.research.google.com/drive/15oP52_7W5dRcAnbgX3tYADsu4R3cjMIf?usp=sharing)
2.  **Chunked Feed Forward Layers** - *How to get a better time-memory trade-off for large feed forward layers?* => see [this colab](https://colab.research.google.com/drive/1xKK32Yhda-iYgtoA3eCrnCVuy_lraQR9?usp=sharing)
3.   **Reversible Residual Layers**  - *How to drastically reduce memory consumption in training by a smart residual architecture?*
4.   **Axial Positional Encodings** - *How to make positional encodings usable for extremely large input sequences?*

The goal of this blog post is to give the reader an **in-depth** understanding of each of the four Reformer features mentioned above. While the explanations are focussed on the Reformer, the reader should get a better intuition under which circumstances each of the four features can be effective for other transformer models as well. 
The four sections are only loosely connected, so they can very well be read individually.

Reformer is part of the 🤗Transformers library. For all users of the Reformer, it is advised to go through this very detailed blog post to better understand how the model works and how to correctly set its configuration. All equations are accompanied by their equivalent name for the Reformer config, *e.g.* `config.<param_name>`, so that the reader can quickly relate to the official docs and configuration file.

**Note**: *Axial Positional Encodings* are not explained in the official Reformer paper, but are extensively used in the official codebase. This blog post gives the first in-depth explanation of Axial Positional Encodings.

## **3. Reversible Residual Layers**

Reversible residual layers were first introduced in [N. Gomez et al](https://arxiv.org/abs/1707.04585) and used to reduce memory consumption when training the popular *ResNet* model. Mathematically, reversible residual layers differ slightly from "real" residual layers, but they do not require the activations to be saved during the forward pass, which can drastically reduce memory consumption during training.

### **Reversible Residual Layers in Reformer**

Let's start by investigating why training a model requires 
much more memory than the inference of the model.

When running a model in inference, the required memory equals more or less the memory it takes to compute the **single** largest tensor in the model.
On the other hand, when training a model, the required memory equals more or less the **sum** of all differentiable tensors.

This is not surprising when considering how auto differentiation works in deep learning frameworks. These lecture [slides](https://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/slides/lec10.pdf) by Roger Grosse of the University of Toronto are great to better understand auto differentiation.

In a nutshell, in order to calculate the gradient of a differentiable function (*e.g.* a layer), auto differentiation requires the gradient of the function's output and the function's input and output tensor. While the gradients are dynamically computed and subsequently discarded, the input and output tensors (*a.k.a* activations) of a function are stored during the forward pass.
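The rule of thumb above (largest single tensor for inference, sum of all stored tensors for training) can be made concrete with a toy calculation. The activation sizes below are hypothetical numbers chosen only for illustration; they are not measurements of any real model.

```python
# Hypothetical activation sizes (in MB) produced by each layer of a toy model.
activation_sizes_mb = [120, 300, 120, 300, 120, 300]

# Inference: an activation can be freed once the next layer has consumed it,
# so the peak is roughly the single largest tensor.
inference_peak_mb = max(activation_sizes_mb)

# Training: every activation is kept until the backward pass reaches it,
# so the peak is roughly the sum of all of them.
training_peak_mb = sum(activation_sizes_mb)

print(f"inference ~{inference_peak_mb} MB, training ~{training_peak_mb} MB")
```

Under this simplified accounting, doubling the number of layers leaves the inference estimate unchanged but doubles the training estimate.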

Alright, let's apply this to a transformer model. A transformer model includes a stack of multiple so-called transformer layers. Each additional transformer layer forces the model to store more activations during the forward pass and thus increases the required memory for training. 
Let's take a more detailed look. A transformer layer essentially consists of two residual layers. The first residual layer represents the *self-attention* mechanism as explained in section 1) and the second residual layer represents the *linear* or feed-forward layers as explained in section 2).

Using the same notation as in the previous notebooks [here](https://colab.research.google.com/drive/15oP52_7W5dRcAnbgX3tYADsu4R3cjMIf?usp=sharing) and [here](https://colab.research.google.com/drive/1xKK32Yhda-iYgtoA3eCrnCVuy_lraQR9#scrollTo=GNs6JrxtglSz), the input of a transformer layer, *i.e.* $\mathbf{X}$, is first normalized$^{1}$ and subsequently processed by the self-attention layer to get the output $\mathbf{Z} = \text{SelfAttn}(\text{LayerNorm}(\mathbf{X}))$. We will abbreviate these two layers with $G$ so that $\mathbf{Z} = G(\mathbf{X})$. 
Next, the residual $\mathbf{Z}$ is added to the input $\mathbf{\overline{Z}} = \mathbf{Z} + \mathbf{X}$ and the sum is fed into the second residual layer - the two linear layers. $\mathbf{\overline{Z}}$ is processed by a second normalization layer, followed by the two linear layers to get $\mathbf{Y} = \text{Linear}(\text{LayerNorm}(\mathbf{Z} + \mathbf{X}))$. We will abbreviate the second normalization layer and the two linear layers with $F$ yielding $\mathbf{Y} = F(\mathbf{\overline{Z}})$. 
Finally, the residual $\mathbf{Y}$ is added to $\mathbf{\overline{Z}}$ to give the output of the transformer layer $\mathbf{\overline{Y}} = \mathbf{Y} + \mathbf{\overline{Z}}$.

Let's illustrate a complete transformer layer using the example of $\mathbf{x}_1, \ldots, \mathbf{x}_{16}$.

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/normal_trans_resnet.png)

To calculate the gradient of *e.g.* the self-attention block $G$, three tensors have to be known beforehand: the gradient $\partial \mathbf{Z}$, the output $\mathbf{Z}$, and the input $\mathbf{X}$. While $\partial \mathbf{Z}$ can be calculated on-the-fly and discarded afterward, the values for $\mathbf{Z}$ and $\mathbf{X}$ have to be calculated and stored during the forward pass since it is not possible to recalculate them easily on-the-fly during backpropagation. Therefore, during the forward pass, large tensor outputs, such as the query-key dot product matrix $\mathbf{Q}\mathbf{K}^T$ or the intermediate output of the linear layers $\mathbf{Y}^{\text{int}}$, have to be stored in memory $^{2}$.

Here, reversible residual layers come to the rescue. The idea is relatively straightforward. The residual block is designed in a way so that instead of having to store the input and output tensor of a function, both can easily be recalculated during the backward pass so that no tensor has to be stored in memory during the forward pass. 
This is achieved by using two input streams $\mathbf{X}^{(1)}, \mathbf{X}^{(2)}$, and two output streams $\mathbf{\overline{Y}}^{(1)}, \mathbf{\overline{Y}}^{(2)}$. The first residual $\mathbf{Z}$ is computed from the first input stream, $\mathbf{Z} = G(\mathbf{X}^{(1)})$, and subsequently added to the second input stream, so that $\mathbf{\overline{Z}} = \mathbf{Z} + \mathbf{X}^{(2)}$. 
Similarly, the residual $\mathbf{Y} = F(\mathbf{\overline{Z}})$ is added to the first input stream again, so that the two output streams are defined by $\mathbf{\overline{Y}}^{(1)} = \mathbf{Y} + \mathbf{X}^{(1)}$ and $\mathbf{\overline{Y}}^{(2)} = \mathbf{X}^{(2)} + \mathbf{Z} = \mathbf{\overline{Z}}$.
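For reference, the two layer types can be written side by side. This only restates the definitions given in the text, using the same symbols $G$ and $F$, and introduces nothing new.

$$
\begin{aligned}
\text{standard:} \quad & \mathbf{\overline{Z}} = \mathbf{X} + G(\mathbf{X}), \quad \mathbf{\overline{Y}} = \mathbf{\overline{Z}} + F(\mathbf{\overline{Z}}) \\
\text{reversible:} \quad & \mathbf{\overline{Y}}^{(2)} = \mathbf{X}^{(2)} + G(\mathbf{X}^{(1)}), \quad \mathbf{\overline{Y}}^{(1)} = \mathbf{X}^{(1)} + F(\mathbf{\overline{Y}}^{(2)})
\end{aligned}
$$

In the reversible case each output depends on only one of the two functions applied to a quantity that is itself available from the outputs, which is what makes the layer invertible.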

The reversible transformer layer can be visualized for $\mathbf{x}_1, \ldots, \mathbf{x}_{16}$ as follows.

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/reformer_benchmark/rev_trans_resnet.png)

As can be seen, the outputs $\mathbf{\overline{Y}}^{(1)}, \mathbf{\overline{Y}}^{(2)}$ are calculated in a very similar way to $\mathbf{\overline{Y}}$ of the non-reversible layer, but they are mathematically different. The authors of Reformer observe in some initial experiments that the performance of a reversible transformer model matches the performance of a standard transformer model. 
The first visible difference to the standard transformer layer is that there are two input streams and two output streams $^{3}$, which at first slightly increases the required memory for the forward pass.
The two-stream architecture is crucial though for not having to save any activations during the forward pass. Let's explain. For backpropagation, the reversible transformer layer has to calculate the gradients $\partial G$ and $\partial F$. In addition to the gradients $\partial \mathbf{Y}$ and $\partial \mathbf{Z}$ which can be calculated on-the-fly, the tensor values $\mathbf{Y}$, $\mathbf{\overline{Z}}$ have to be known for $\partial F$ and the tensor values $\mathbf{Z}$ and $\mathbf{X}^{(1)}$ for $\partial G$ to make auto-differentiation work.

If we assume to know $\mathbf{\overline{Y}}^{(1)}, \mathbf{\overline{Y}}^{(2)}$, it can easily be seen from the graph that one can calculate $\mathbf{X}^{(1)}, \mathbf{X}^{(2)}$ as follows. Since $\mathbf{\overline{Y}}^{(2)} = \mathbf{\overline{Z}}$, the residual $\mathbf{Y} = F(\mathbf{\overline{Y}}^{(2)})$ can be recomputed, which gives $\mathbf{X}^{(1)} = \mathbf{\overline{Y}}^{(1)} - F(\mathbf{\overline{Y}}^{(2)})$. Great, now that $\mathbf{X}^{(1)}$ is known, $\mathbf{X}^{(2)}$ can be computed by $\mathbf{X}^{(2)} = \mathbf{\overline{Y}}^{(2)} - G(\mathbf{X}^{(1)})$. Alright now, $\mathbf{Z}$ and $\mathbf{Y}$ are trivial to compute via $\mathbf{Y} = \mathbf{\overline{Y}}^{(1)} - \mathbf{X}^{(1)}$ and $\mathbf{Z} = \mathbf{\overline{Y}}^{(2)} - \mathbf{X}^{(2)}$. So as a conclusion, if only the outputs $\mathbf{\overline{Y}}^{(1)}, \mathbf{\overline{Y}}^{(2)}$ of the **last** reversible transformer layer are stored during the forward pass, all other relevant activations can be derived by making use of $G$ and $F$ during the backward pass and passing $\mathbf{X}^{(1)}$ and $\mathbf{X}^{(2)}$ on as the outputs of the preceding layer. The overhead of two additional forward passes per reversible transformer layer during backpropagation, one of $G$ and one of $F$, is traded against not having to store any activations during the forward pass. Not a bad deal!
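The reconstruction can be checked numerically with a minimal sketch. The functions `G` and `F` below are hypothetical toy stand-ins (element-wise functions on plain Python lists), not the real self-attention and feed-forward blocks; the point is only that the inputs are recovered from the outputs without inverting `G` or `F`.

```python
import math

def G(x):
    """Toy stand-in for LayerNorm + self-attention (hypothetical)."""
    return [math.tanh(v) for v in x]

def F(x):
    """Toy stand-in for LayerNorm + feed-forward (hypothetical, not invertible)."""
    return [0.5 * v * v for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def sub(a, b):
    return [u - v for u, v in zip(a, b)]

def reversible_forward(x1, x2):
    z_bar = add(G(x1), x2)   # Z_bar = G(X1) + X2, which is also Y_bar2
    y1 = add(F(z_bar), x1)   # Y_bar1 = F(Z_bar) + X1
    return y1, z_bar

def reversible_inverse(y1, y2):
    x1 = sub(y1, F(y2))      # X1 = Y_bar1 - F(Y_bar2)
    x2 = sub(y2, G(x1))      # X2 = Y_bar2 - G(X1)
    return x1, x2

x1 = [0.1, -0.4, 0.7]
x2 = [0.3, 0.2, -0.5]
y1, y2 = reversible_forward(x1, x2)
r1, r2 = reversible_inverse(y1, y2)

# The inputs are recovered up to floating-point rounding.
recovered = all(
    math.isclose(a, b, abs_tol=1e-9) for a, b in zip(x1 + x2, r1 + r2)
)
print(recovered)
```

Note that `reversible_inverse` calls `F` and `G` once each, which corresponds to the recomputation overhead paid during the backward pass.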

**Note**: Major deep learning frameworks have recently released code that allows storing only certain activations and recomputing larger ones during backpropagation (TensorFlow [here](https://www.tensorflow.org/api_docs/python/tf/recompute_grad) and PyTorch [here](https://pytorch.org/docs/stable/checkpoint.html)). For standard (non-reversible) residual layers, this still means that at least one activation has to be stored for each transformer layer, but by defining which activations can dynamically be recomputed a lot of memory can be saved.

---
$^{1}$ In the previous two sections, we have omitted the layer norm layers preceding both the self-attention layer and the linear layers. The reader should know that $\mathbf{X}$ and $\mathbf{\overline{Z}}$ are both processed by layer normalization before being fed into self-attention and the linear layers respectively.
$^{2}$ While in the design the dimension of $\mathbf{Q}\mathbf{K}^T$ is written as $n \times n$, in a *LSH self-attention* or *local self-attention* layer the dimension would only be $n \times l_c \times n_h$ or $n \times l_c$ respectively, with $l_c$ being the chunk length and $n_h$ the number of hashes.
$^{3}$ In the first reversible transformer layer $\mathbf{X}^{(2)}$ is set to be equal to $\mathbf{X}^{(1)}$.


### **Benchmark**

In order to measure the effect of reversible residual layers, we will compare the memory consumption of BERT with Reformer in training for an increasing number of layers.

In [1]:
#@title Installs and Imports
# pip installs
!pip -qq install git+https://github.com/huggingface/transformers.git
!pip install -qq py3nvml

from transformers import ReformerConfig, BertConfig, PyTorchBenchmark, PyTorchBenchmarkArguments


Let's measure the required memory for the standard `bert-base-uncased` BERT model by increasing the number of layers from 4 to 12.

In [2]:
config_4_layers_bert = BertConfig.from_pretrained("bert-base-uncased", num_hidden_layers=4)
config_8_layers_bert = BertConfig.from_pretrained("bert-base-uncased", num_hidden_layers=8)
config_12_layers_bert = BertConfig.from_pretrained("bert-base-uncased", num_hidden_layers=12)
benchmark_args = PyTorchBenchmarkArguments(sequence_lengths=[512], batch_sizes=[8], models=["Bert-4-Layers", "Bert-8-Layers", "Bert-12-Layers"], training=True, no_inference=True, no_speed=True, no_env_print=True)
benchmark = PyTorchBenchmark(configs=[config_4_layers_bert, config_8_layers_bert, config_12_layers_bert], args=benchmark_args)
result = benchmark.run()

--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length    Memory in MB 
--------------------------------------------------------------------------------
        Bert-4-Layers                8              512             4103     
        Bert-8-Layers                8              512             5759     
        Bert-12-Layers               8              512             7415     
--------------------------------------------------------------------------------


It can be seen that the required memory grows linearly with the number of BERT layers: each additional layer adds ca. 400MB.

In [3]:
config_4_layers_reformer = ReformerConfig.from_pretrained("google/reformer-enwik8", num_hidden_layers=4, num_hashes=1)
config_8_layers_reformer = ReformerConfig.from_pretrained("google/reformer-enwik8", num_hidden_layers=8, num_hashes=1)
config_12_layers_reformer = ReformerConfig.from_pretrained("google/reformer-enwik8", num_hidden_layers=12, num_hashes=1)
benchmark_args = PyTorchBenchmarkArguments(sequence_lengths=[512], batch_sizes=[8], models=["Reformer-4-Layers", "Reformer-8-Layers", "Reformer-12-Layers"], training=True, no_inference=True, no_speed=True, no_env_print=True)
benchmark = PyTorchBenchmark(configs=[config_4_layers_reformer, config_8_layers_reformer, config_12_layers_reformer], args=benchmark_args)
result = benchmark.run()

--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length    Memory in MB 
--------------------------------------------------------------------------------
      Reformer-4-Layers              8              512             4607     
      Reformer-8-Layers              8              512             4987     
      Reformer-12-Layers             8              512             5367     
--------------------------------------------------------------------------------


For Reformer, on the other hand, adding a layer adds significantly less memory in practice. Adding a single layer increases the required memory on average by less than 100MB so that a much larger 12-Layer `reformer-enwik8` model requires less memory than a 12-Layer `bert-base-uncased` model.
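The per-layer figures quoted for both models can be reproduced from the two benchmark tables. The short sketch below copies the reported memory readings and computes the average growth per added layer; the helper `mb_per_layer` is a hypothetical name introduced only for this illustration.

```python
# Memory readings (MB) copied from the two benchmark tables,
# keyed by the number of layers.
bert_mb = {4: 4103, 8: 5759, 12: 7415}
reformer_mb = {4: 4607, 8: 4987, 12: 5367}

def mb_per_layer(readings):
    """Average memory growth per added layer between the smallest and largest model."""
    fewest, most = min(readings), max(readings)
    return (readings[most] - readings[fewest]) / (most - fewest)

bert_slope = mb_per_layer(bert_mb)
reformer_slope = mb_per_layer(reformer_mb)
print(f"BERT: {bert_slope:.0f} MB/layer, Reformer: {reformer_slope:.0f} MB/layer")
```

On these readings a BERT layer costs a bit more than four times as much additional training memory as a Reformer layer, even though the 4-layer Reformer starts from a higher baseline.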