<h1><center>The Transformer Blog: fairseq edition</center></h1>


The Transformer architecture was presented in ["Attention is All You Need"](https://arxiv.org/abs/1706.03762) and introduced a new architecture for many NLP tasks. In this post we present an explanation of the Transformer architecture focusing on the [fairseq](https://github.com/pytorch/fairseq) implementation. We believe this could be useful for researchers and developers working on this framework.

The blog is based on [The annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) and [Fairseq Transformer, BART](https://yinghaowang.xyz/technology/2020-03-14-FairseqTransformer.html) blogs.

# Background <a class="anchor" id="background"></a>

The Transformer was introduced as an alternative to RNNs and ConvNets that relates any two positions of its input in a constant number of operations. This is possible thanks to the self-attention module, an attention mechanism that relies only on the input sequence of a single data sample to build representations of it.


# Model Architecture <a class="anchor" id="model-architecture"></a>

The Transformer is based on a stack of encoders and another stack of decoders. The encoder maps an input sequence of token representations $(x_1, ..., x_{src\_length})$ to a sequence of continuous representations $\mathbf{encoder\_out} = (encoder\_out_1, ..., encoder\_out_{src\_length})$. Given $\mathbf{encoder\_out}$, the decoder then generates an output sequence $\mathcal{Y} = (output_0,...,output_{tgt\_length})$ of symbols one element at a time. At each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next token.

<img src="The_Transformer_Blog_files/transformer_illustrated.png" style="width:650px;height:370px;" align="center"/>
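
The auto-regressive loop described above can be sketched without any framework. This is only an illustration: `greedy_decode`, `toy_next_token`, `BOS` and `EOS` are hypothetical names that do not come from fairseq, and the toy "model" simply copies the source sequence instead of running an encoder and a decoder.

```python
from typing import Callable, List

BOS, EOS = 0, 1  # hypothetical special-token indices, for illustration only


def greedy_decode(
    next_token: Callable[[List[int], List[int]], int],
    src_tokens: List[int],
    max_len: int = 10,
) -> List[int]:
    """Generate tokens one at a time, feeding previous outputs back in."""
    prev_output_tokens = [BOS]
    for _ in range(max_len):
        # The "model" sees the source and everything generated so far.
        token = next_token(src_tokens, prev_output_tokens)
        prev_output_tokens.append(token)
        if token == EOS:
            break
    return prev_output_tokens[1:]  # drop BOS


def toy_next_token(src_tokens: List[int], prev_output_tokens: List[int]) -> int:
    """Stand-in for encoder + decoder: copies the source, then emits EOS."""
    step = len(prev_output_tokens) - 1  # number of tokens generated so far
    return src_tokens[step] if step < len(src_tokens) else EOS


output = greedy_decode(toy_next_token, src_tokens=[5, 6, 7])
```

Each iteration corresponds to one decoding time step: the token produced at step $t$ becomes part of the decoder input at step $t+1$.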

This model is implemented in fairseq as <code class="language-plaintext highlighter-rouge">TransformerModel</code> in [fairseq/models/transformer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer.py).

In [None]:
class TransformerModel(FairseqEncoderDecoderModel):
...
  def forward(
          self,
          src_tokens,
          src_lengths,
          prev_output_tokens,
          return_all_hiddens: bool = True,
          features_only: bool = False,
          alignment_layer: Optional[int] = None,
          alignment_heads: Optional[int] = None,
      ):
          """
          Run the forward pass for an encoder-decoder model.

          Copied from the base class, but without ``**kwargs``,
          which are not supported by TorchScript.
          """
          encoder_out = self.encoder(
              src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
          )
          decoder_out = self.decoder(
              prev_output_tokens,
              encoder_out=encoder_out,
              features_only=features_only,
              alignment_layer=alignment_layer,
              alignment_heads=alignment_heads,
              src_lengths=src_lengths,
              return_all_hiddens=return_all_hiddens,
          )
          return decoder_out

## Encoder and Decoder Stacks <a class="anchor" id="encoder-and-decoder"></a>

## Encoder

The encoder (<code class="language-plaintext highlighter-rouge">TransformerEncoder</code>) is composed of a stack of $N=encoder\_layers$ identical layers.

The encoder receives a list of tokens $src\_tokens=(token_{0},...,token_{src\_length})$, which are converted to continuous vector representations by <code class="language-plaintext highlighter-rouge">x = self.forward_embedding(src_tokens, token_embeddings)</code>. The result is the sum of the (scaled) embedding lookup and the positional embedding: <code class="language-plaintext highlighter-rouge">x = embed + self.embed_positions(src_tokens)</code>.

<img src="The_Transformer_Blog_files/encoder_input.png" style="width:480px;height:150px;" align="center"/>
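
As a rough sketch of this sum, the snippet below builds the encoder input for a toy vocabulary. It assumes a scaling factor of $\sqrt{embed\_dim}$ and the sinusoidal positional encoding formula from "Attention is All You Need". The exact layout of the sine and cosine components and the position offsets used by fairseq may differ, and the names `sinusoidal_position`, `embed` and `table` are hypothetical.

```python
import math
from typing import List


def sinusoidal_position(pos: int, dim: int) -> List[float]:
    """Positional encoding from the paper (interleaved sin/cos); an assumption here."""
    vec = []
    for i in range(dim):
        angle = pos / (10000 ** (2 * (i // 2) / dim))
        vec.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return vec


def embed(tokens: List[int], table: List[List[float]]) -> List[List[float]]:
    """Scaled embedding lookup plus positional embedding, one row per token."""
    dim = len(table[0])
    embed_scale = math.sqrt(dim)  # assumed scaling factor sqrt(embed_dim)
    out = []
    for pos, tok in enumerate(tokens):
        scaled = [embed_scale * v for v in table[tok]]
        out.append([s + p for s, p in zip(scaled, sinusoidal_position(pos, dim))])
    return out


# Toy vocabulary of 3 tokens with embed_dim = 4 (values are arbitrary).
table = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]
x = embed([2, 0, 1], table)  # shape: src_length x embed_dim
```

The same token receives a different vector depending on its position, which is how the model gets access to word order.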

From now on, let's consider $\mathcal{X}^L = (x_{0},\cdots,x_{src\_length})$ as the input of the $L$-th encoder layer. $\mathcal{X}^{1}$ then refers to the vector representations of the input sequence tokens that enter the first layer, after computing <code class="language-plaintext highlighter-rouge">self.forward_embedding</code>. Note that although $\mathcal{X}^L$ is represented in fairseq as a tensor of shape <code class="language-plaintext highlighter-rouge">src_length x batch x encoder_embed_dim</code>, for the sake of simplicity we omit the second dimension in the upcoming mathematical notation and treat it as a <code class="language-plaintext highlighter-rouge">src_length x encoder_embed_dim</code> matrix.

In [None]:
class TransformerEncoder(FairseqEncoder):
...
  def forward(
        self,
        src_tokens,
        src_lengths,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # batch x src_lengths x encoder_embed_dim
        #                     -> src_lengths x batch x encoder_embed_dim
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)

        encoder_states = [] if return_all_hiddens else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x,  # src_lengths x batch x encoder_embed_dim
            encoder_padding_mask=encoder_padding_mask,
            encoder_embedding=encoder_embedding,
            encoder_states=encoder_states, # List[src_lengths x batch x encoder_embed_dim]
            src_tokens=None,
            src_lengths=None,
        )

This returns an <code class="language-plaintext highlighter-rouge">EncoderOut</code> NamedTuple with the following fields:

* encoder_out: of shape <code class="language-plaintext highlighter-rouge">src_length x batch x encoder_embed_dim</code>, the output of the last encoder layer which, as we will see, is used by the decoder. Note that this is the same as $\mathcal{X}^{N+1}$.
* encoder_padding_mask: of shape <code class="language-plaintext highlighter-rouge">batch x src_length</code>. Binary mask tensor where padding elements are indicated by 1.
* encoder_embedding: of shape <code class="language-plaintext highlighter-rouge">batch x src_length x encoder_embed_dim</code>, the (scaled) word embedding lookup. Unlike <code class="language-plaintext highlighter-rouge">x</code>, it is not transposed in the code above.
* encoder_states: of shape <code class="language-plaintext highlighter-rouge">list[src_length x batch x encoder_embed_dim]</code>, the outputs of the intermediate encoder layers. It is only filled when <code class="language-plaintext highlighter-rouge">return_all_hiddens</code> is set.
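
To make the padding mask concrete, here is a small list-based equivalent of <code class="language-plaintext highlighter-rouge">src_tokens.eq(self.padding_idx)</code>. The padding index `PAD = 1` and the side on which the padding is placed are arbitrary choices for this illustration.

```python
from typing import List

PAD = 1  # hypothetical padding index, for illustration only


def padding_mask(src_tokens: List[List[int]], padding_idx: int) -> List[List[bool]]:
    """Elementwise comparison with the padding index: True marks padding."""
    return [[tok == padding_idx for tok in sentence] for sentence in src_tokens]


# Batch of two sentences; the second is padded to the length of the first.
src_tokens = [[4, 5, 6, 7], [8, 9, PAD, PAD]]
mask = padding_mask(src_tokens, PAD)  # shape: batch x src_length
```

The attention module uses this mask so that padded positions receive no attention weight.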


### Encoder Layer

The previous code snippet shows a loop over the layers of the encoder block, <code class="language-plaintext highlighter-rouge">for layer in self.layers</code>. Each layer is implemented in fairseq in <code class="language-plaintext highlighter-rouge">class TransformerEncoderLayer(nn.Module)</code> inside [fairseq/modules/transformer_layer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/transformer_layer.py) and computes the following operations:

<img src="The_Transformer_Blog_files/encoder_layer_transformer_full.png" style="width:400px;height:350px;" align="center"/>

The input of the encoder layer is passed through the self-attention module <code class="language-plaintext highlighter-rouge">self.self_attn</code>. Dropout (<code class="language-plaintext highlighter-rouge">self.dropout_module(x)</code>) is then applied before reaching the Add & Normalize module, which is made of a residual connection (<code class="language-plaintext highlighter-rouge">self.residual_connection(x, residual)</code>) and a layer normalization (LayerNorm, <code class="language-plaintext highlighter-rouge">self.self_attn_layer_norm(x)</code>). Depending on <code class="language-plaintext highlighter-rouge">self.normalize_before</code>, the layer normalization is applied either to the input of the sub-layer or to the result of the residual connection.

In [None]:
class TransformerEncoderLayer(nn.Module):
...
  def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
    if attn_mask is not None:
      attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)

    residual = x
    if self.normalize_before:
        x = self.self_attn_layer_norm(x)
    x, _ = self.self_attn(
        query=x,
        key=x,
        value=x,
        key_padding_mask=encoder_padding_mask,
        attn_mask=attn_mask,
    )
    x = self.dropout_module(x)
    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        x = self.self_attn_layer_norm(x)
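
The effect of `normalize_before` can be isolated in a few lines. The sketch below is a simplification under stated assumptions: the layer normalization has no learned scale or bias, dropout is omitted, and `residual_block` and `double` are hypothetical names (`double` stands in for the self-attention module).

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Normalize to zero mean and unit variance (no learned scale or bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_block(
    x: Vector, sublayer: Callable[[Vector], Vector], normalize_before: bool
) -> Vector:
    """Apply a sub-layer with a residual connection and layer normalization."""
    residual = x
    if normalize_before:  # pre-norm: normalize the sub-layer input
        x = layer_norm(x)
    x = sublayer(x)
    x = [a + b for a, b in zip(x, residual)]  # residual connection
    if not normalize_before:  # post-norm: normalize the sum
        x = layer_norm(x)
    return x


double = lambda v: [2.0 * a for a in v]  # hypothetical stand-in for self-attention
post = residual_block([1.0, 2.0, 3.0], double, normalize_before=False)
pre = residual_block([1.0, 2.0, 3.0], double, normalize_before=True)
```

With post-normalization the block output is itself normalized, while with pre-normalization the unnormalized residual is carried through to the output.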

Then, the result is passed through a position-wise feed-forward network composed of two fully connected layers, <code class="language-plaintext highlighter-rouge">fc1</code> and <code class="language-plaintext highlighter-rouge">fc2</code>, with a ReLU activation in between (<code class="language-plaintext highlighter-rouge">self.activation_fn(self.fc1(x))</code>) and dropout (<code class="language-plaintext highlighter-rouge">self.dropout_module(x)</code>).

$$\mathrm{Feed Forward}(x)=\max(0, xW_1 + b_1) W_2 + b_2$$
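
The formula can be checked by hand on a tiny example. In this sketch the weights are arbitrary toy values, dropout is left out, and the matrices are stored with one row per input feature so that the code matches the $xW + b$ notation of the formula. The helper names `linear` and `feed_forward` are hypothetical.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Compute x W + b, with W stored as in_dim rows of out_dim columns."""
    return [
        sum(x[i] * weight[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def feed_forward(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """max(0, x W1 + b1) W2 + b2, applied to a single position."""
    hidden = [max(0.0, h) for h in linear(x, w1, b1)]  # ReLU
    return linear(hidden, w2, b2)


# embed_dim = 2, hidden dim = 3; weights are arbitrary toy values.
w1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.5, -0.5]
y = feed_forward([1.0, 1.0], w1, b1, w2, b2)
```

Here $xW_1 + b_1 = (1, 1, -0.5)$, the ReLU zeroes the negative entry, and the second layer maps the result back to the embedding dimension. The same weights are applied independently to every position of the sequence.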



In [None]:
    residual = x
    if self.normalize_before:
        x = self.final_layer_norm(x)

    x = self.activation_fn(self.fc1(x))
    x = self.activation_dropout_module(x)
    x = self.fc2(x)
    x = self.dropout_module(x)
       

Finally, a residual connection is made before another layer normalization layer.

In [None]:
    x = self.residual_connection(x, residual)
    if not self.normalize_before:
        x = self.final_layer_norm(x)
    return x

#### Self-attention

As we have seen, the input of each encoder layer is first passed through a self-attention layer ([fairseq/modules/multihead_attention.py](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/multihead_attention.py)).

In [None]:
class MultiheadAttention(nn.Module):
...
  def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:

Each encoder layer input $\mathcal{X}^L$, shown as <code class="language-plaintext highlighter-rouge">query</code> below since three copies of $query$ are passed to the self-attention module, is multiplied by three weight matrices learned during the training process: $W^{Q}, W^{K}$ and $W^{V}$, obtaining $Q$, $K$ and $V$. Each row of these output matrices represents the query, key and value vector of one token in the sequence, written as $q$, $k$ and $v$ in the formulas that follow.

In [None]:
    if self.self_attention:
      q = self.q_proj(query) # Q
      k = self.k_proj(query) # K
      v = self.v_proj(query) # V
    q *= self.scaling

The self-attention module does the following operation:

$$
\mathrm{softmax}(\frac{QK^\top}{\sqrt{d_{k}}})V
$$

In [None]:
    attn_weights = torch.bmm(q, k.transpose(1, 2)) # QK^T multiplication

Given a token in the input sequence, $i \in \mathcal{X}^L$, its query vector $q_{i}$ is passed to the self-attention function. Then, by means of dot products, scalar values (scores) are obtained between the query vector $q_{i}$ and every key vector of the input sequence $k_{j} \forall j \in \mathcal{X}^L$. The intuition is that this performs a similarity operation: similar query and key vectors yield higher scores.

These scores represent how much attention the self-attention layer pays to the other parts of the sequence when encoding $i$. Multiplying $q_{i}$ by the matrix $K^{\top}$ outputs a list of <code class="language-plaintext highlighter-rouge">src_length</code> scores. The scores are then passed through a softmax function, which gives values bounded between 0 and 1 that sum to 1:

$$\alpha_{i} = \text{softmax}(\frac{\mathbf{q}_i {K}^\top}{\sqrt{d_k}})
= \frac{\exp(\frac{\mathbf{q}_i {K}^\top}{\sqrt{d_k}})}{ \sum_{j \in \mathcal{X}^L} \exp(\frac{\mathbf{q}_i k_{j}^\top}{\sqrt{d_k}})}$$

The division by the square root of the dimension of the key vectors $d_{k}$ (which gives more stable gradients) does not appear after the product in fairseq. It is applied beforehand to the queries instead, with <code class="language-plaintext highlighter-rouge">q *= self.scaling</code>.
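
Both orders of scaling give the same attention weights, since the dot product is linear in the query. The sketch below checks this numerically for one query vector, assuming that the scaling factor equals $d_k^{-1/2}$. The vectors are arbitrary toy values and the helper names are hypothetical.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


d_k = 4
scaling = d_k ** -0.5  # assumed value of the scaling factor
q_i = [1.0, 0.0, 2.0, -1.0]
keys = [[1.0, 1.0, 0.0, 0.0], [0.0, 2.0, 1.0, 1.0], [2.0, 0.0, 1.0, 0.0]]

# Formula order: scale the scores after the dot product.
alpha_formula = softmax([dot(q_i, k) / math.sqrt(d_k) for k in keys])

# Order described above: scale the query first, then take dot products.
q_scaled = [scaling * v for v in q_i]
alpha_prescaled = softmax([dot(q_scaled, k) for k in keys])
```

The third key has the largest dot product with $q_i$, so it receives the largest weight in both cases.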


Given the sentence "the nice cat walks away from us" for the token $i=\text{from}$, its corresponding attention weights $\alpha_{i}$ for every other token $j$ in the input sequence could be:

<img src="The_Transformer_Blog_files/probs.jpg" style="width:600px;height:250px;" align="center"/>

In [None]:
    attn_weights_float = utils.softmax(
                attn_weights, dim=-1, onnx_trace=self.onnx_trace
            )
    attn_weights = attn_weights_float.type_as(attn_weights)

Once we have normalized scores for every pair of tokens $\{i,j\}$, we multiply each weight $\alpha_{i,j}$ (the $j$-th element of $\alpha_{i}$) by the value vector $v_{j} \forall j \in \mathcal{X}^L$ (each row in matrix $V$) and finally sum up those vectors:

$$
z_{i} = \sum_{j \in \mathcal{X}^L}\alpha_{i,j}v_{j}
$$

Where $z_{i}$ represents row $i$ of $Z$. By doing the matrix multiplication of the attention weight matrix <code class="language-plaintext highlighter-rouge">attn_weights</code> and $V$, $\mathrm{softmax}(\frac{QK^{T}}{\sqrt{d_k}})V$, we directly get matrix $Z$.
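
Putting the pieces together, a single attention head can be written with plain Python lists. This is a minimal sketch without masks, dropout or batching, and `matmul`, `softmax_rows` and `attention` are hypothetical helper names.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m: Matrix) -> Matrix:
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V for a single head, without masks or dropout."""
    d_k = len(k[0])
    k_t = [list(col) for col in zip(*k)]  # K^T
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, k_t)]
    return matmul(softmax_rows(scores), v)


# Identical keys give uniform weights, so each z_i is the mean of the values.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [1.0, 1.0]]
v = [[2.0, 0.0], [0.0, 4.0]]
z = attention(q, k, v)
```

Because both keys are identical, every query attends equally to both positions and each row of $Z$ is the average of the two value vectors.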

In [None]:
    attn_probs = self.dropout_module(attn_weights)
    assert v is not None
    attn = torch.bmm(attn_probs, v)

This process is done in parallel in each of the self-attention heads. So, in total <code class="language-plaintext highlighter-rouge">encoder_attention_heads</code> matrices are output. Each head has its own $W^{Q}, W^{K}$ and $W^{V}$ weight matrices which are randomly initialized, so the result leads to different representation subspaces in each of the self-attention heads.

The output matrices $Z$ of every self-attention head are concatenated into a single matrix, to which a linear transformation $W^{O}$ (<code class="language-plaintext highlighter-rouge">self.out_proj</code>) is applied:

$$attn = \mathrm{Concat}(Z_{head_{1}},\cdots,Z_{head_{h}})W^{O}$$
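
The concatenation and the output projection can be sketched as follows. The per-head outputs are toy values, the bias term is omitted, and $W^{O}$ is set to the identity only to keep the result readable. `concat_heads` and `project` are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]


def concat_heads(heads: List[Matrix]) -> Matrix:
    """Concatenate per-head outputs Z_head along the feature dimension."""
    length = len(heads[0])
    return [[v for head in heads for v in head[t]] for t in range(length)]


def project(x: Matrix, w_o: Matrix) -> Matrix:
    """Multiply each row of x by the output projection W^O (no bias)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w_o)] for row in x]


# Two heads, two tokens, head_dim = 2, so embed_dim = num_heads * head_dim = 4.
z_head_1 = [[1.0, 2.0], [3.0, 4.0]]
z_head_2 = [[5.0, 6.0], [7.0, 8.0]]
concatenated = concat_heads([z_head_1, z_head_2])

# Toy W^O (embed_dim x embed_dim); identity is used to keep the result readable.
w_o = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
attn = project(concatenated, w_o)
```

After concatenation each token is again represented by a vector of size `embed_dim`, which is what allows the residual connection with the layer input.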

In [None]:
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights

Notice that <code class="language-plaintext highlighter-rouge">attn_probs</code> has dimensions <code class="language-plaintext highlighter-rouge">(bsz * self.num_heads, tgt_len, src_len)</code>.


To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $d_{\text{model}}=$<code class="language-plaintext highlighter-rouge">encoder_embed_dim</code>.

## Decoder

The decoder is composed of a stack of $N=decoder\_layers$ identical layers.

The <code class="language-plaintext highlighter-rouge">TransformerDecoder</code> inherits from <code class="language-plaintext highlighter-rouge">FairseqIncrementalDecoder</code>. It differs from the encoder in that it performs incremental decoding at inference time. This means that at each time step a forward pass is done through the decoder, generating one output token, which is then fed as input to the forward pass of the next time step. Specifically, it takes the encoder's output <code class="language-plaintext highlighter-rouge">encoder_out.encoder_out</code> as $key$ and $value$ matrices (in every decoder layer) and <code class="language-plaintext highlighter-rouge">prev_output_tokens</code> to generate one feature vector per target token at each time step (<code class="language-plaintext highlighter-rouge">tgt_len = 1</code> in each forward pass). This feature vector is then passed through a linear layer (<code class="language-plaintext highlighter-rouge">self.output_layer(x)</code>) followed by a softmax to get a probability distribution over the target language vocabulary.

Following the beam search algorithm, the top <code class="language-plaintext highlighter-rouge">beam</code> hypotheses are chosen and fed as input to the decoder (<code class="language-plaintext highlighter-rouge">prev_output_tokens</code>) for the next time step.

<img src="The_Transformer_Blog_files/transformer_decoding.gif"  style="width:750px;height:400px;" align="center"/>
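
The selection of the top <code class="language-plaintext highlighter-rouge">beam</code> hypotheses at each time step can be illustrated with a toy example. This is a simplified sketch, not the fairseq implementation: it ignores end-of-sentence handling, length normalization and batching, the log-probabilities are made up, and `beam_step` is a hypothetical name.

```python
import heapq
from typing import Dict, List, Tuple

Hypothesis = Tuple[float, List[int]]  # (cumulative log-probability, tokens)


def beam_step(
    hypotheses: List[Hypothesis],
    next_log_probs: List[Dict[int, float]],
    beam: int,
) -> List[Hypothesis]:
    """Extend every hypothesis by every candidate token and keep the best `beam`."""
    candidates = []
    for (score, tokens), log_probs in zip(hypotheses, next_log_probs):
        for token, log_prob in log_probs.items():
            candidates.append((score + log_prob, tokens + [token]))
    return heapq.nlargest(beam, candidates, key=lambda c: c[0])


BOS = 0  # hypothetical beginning-of-sentence index
hypotheses = [(0.0, [BOS])]
# Made-up decoder outputs: log-probabilities over a 3-token vocabulary.
step_1 = beam_step(hypotheses, [{3: -0.1, 4: -2.5, 5: -3.0}], beam=2)
step_2 = beam_step(
    step_1,
    [{3: -1.0, 4: -0.2, 5: -4.0}, {3: -0.3, 4: -0.4, 5: -5.0}],
    beam=2,
)
```

The token lists kept after each step play the role of <code class="language-plaintext highlighter-rouge">prev_output_tokens</code> for the next forward pass. In this example both survivors of the second step extend the same first-step hypothesis.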

In [None]:
class TransformerDecoder(FairseqIncrementalDecoder):
...
  def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        
    x, extra = self.extract_features(
        prev_output_tokens,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
        full_context_alignment=full_context_alignment,
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
    )
    if not features_only:
        x = self.output_layer(x)
    return x, extra

In [None]:
def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
    return self.extract_features_scriptable(
        prev_output_tokens,
        encoder_out,
        incremental_state,
        full_context_alignment,
        alignment_layer,
        alignment_heads,
    )

In the first time step, <code class="language-plaintext highlighter-rouge">prev_output_tokens</code> represents the beginning of sentence (BOS) token index. Its embedding <code class="language-plaintext highlighter-rouge">self.embed_tokens(prev_output_tokens)</code> enters the decoder as a tensor of shape <code class="language-plaintext highlighter-rouge">beam*batch x tgt_len x decoder_embed_dim</code>. As in the encoder, a positional embedding is added to the input token embedding.

In [None]:
def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
  ..
    
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)
    if positions is not None:
            x += positions
    attn: Optional[Tensor] = None
    inner_states: List[Optional[Tensor]] = [x]
    for idx, layer in enumerate(self.layers):
        if incremental_state is None and not full_context_alignment:
            self_attn_mask = self.buffered_future_mask(x)
        else:
            self_attn_mask = None

        x, layer_attn, _ = layer(
            x,
            encoder_out.encoder_out if encoder_out is not None else None,
            encoder_out.encoder_padding_mask if encoder_out is not None else None,
            incremental_state,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
            need_attn=bool((idx == alignment_layer)),
            need_head_weights=bool((idx == alignment_layer)),
        )
        inner_states.append(x)
        if layer_attn is not None and idx == alignment_layer:
            attn = layer_attn.float().to(x)

    if attn is not None:
        if alignment_heads is not None:
            attn = attn[:alignment_heads]

        # average probabilities over heads
        attn = attn.mean(dim=0)

    if self.layer_norm is not None:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    return x, {"attn": [attn], "inner_states": inner_states}
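
When no `incremental_state` is given (for example during training, where the whole target sequence is processed at once), the code above builds a self-attention mask that hides future positions. The sketch below shows the kind of additive mask this is assumed to be, with `-inf` above the diagonal. `future_mask` is a hypothetical helper, not the fairseq function.

```python
from typing import List

NEG_INF = float("-inf")


def future_mask(tgt_len: int) -> List[List[float]]:
    """Additive mask: 0 where position j <= i, -inf where j > i (the future)."""
    return [
        [0.0 if j <= i else NEG_INF for j in range(tgt_len)]
        for i in range(tgt_len)
    ]


mask = future_mask(3)
# Adding the mask to raw scores removes future positions after the softmax,
# because exp(-inf) = 0.
scores = [[0.3, 1.2, -0.5], [0.1, 0.4, 2.0], [1.0, 1.0, 1.0]]
masked = [
    [s + m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)
]
```

Row $i$ of the masked scores only keeps columns $j \le i$, so each target position can attend to itself and to earlier positions only.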

### Decoder Layer

The previous code snippet shows a loop over the layers of the decoder block, <code class="language-plaintext highlighter-rouge">for idx, layer in enumerate(self.layers):</code>. Each layer is implemented in fairseq in <code class="language-plaintext highlighter-rouge">class TransformerDecoderLayer(nn.Module)</code> inside [fairseq/modules/transformer_layer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/modules/transformer_layer.py) and computes the following operations:

<img src="The_Transformer_Blog_files/Decoder.png"  style="width:350px;height:280px;" align="center"/>

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer (Encoder-Decoder Attention), which performs multi-head attention using the output of the encoder stack as $key$ and $value$ matrices and the output of the decoder self-attention module as $query$. Similar to the encoder, it employs residual connections around each of the sub-layers, followed by layer normalization.

In [None]:
class TransformerDecoderLayer(nn.Module):
    ..
    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):

        ...
        
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        
        

During incremental decoding, <code class="language-plaintext highlighter-rouge">prev_output_tokens</code> $(output_{0},...,output_{t-1})$ enter the self-attention module as key and value vectors. However, only the output token of the last time step, $output_{t-1}$, enters as a query vector. The query therefore has length one along the time dimension, that is, a single query vector is used instead of a full matrix $Q$.
As before, scalar values (scores) are obtained between the query vector $q_{t-1}$ and every key vector of the whole sequence of previous tokens $k_{j} \forall j \in \mathcal{Y}_{<t}$, where $\mathcal{Y}$ represents the decoder output sequence.

<img src="The_Transformer_Blog_files/incremental_decoding.png"   align="center"/>

$$
z_{t} = \sum_{j \in \mathcal{Y}_{<t}}\alpha_{t,j}v_{j}
$$

Now, just one vector $z_{t}$ is generated at each time step by each head as a weighted average of the $v$ vectors.
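
The idea of reusing the keys and values of previous time steps can be sketched with a small cache. This is an illustration under simplifying assumptions (a single head, no projections, no batching), and the class `IncrementalSelfAttention` is hypothetical. Its attribute names only mirror the `prev_key` and `prev_value` entries of the saved state shown earlier.

```python
import math
from typing import List

Vector = List[float]


def attend(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Weighted average of `values` for a single query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [
        sum(e / total * value[d] for e, value in zip(exps, values))
        for d in range(len(values[0]))
    ]


class IncrementalSelfAttention:
    """Caches keys and values so each step only needs the newest query."""

    def __init__(self) -> None:
        self.prev_key: List[Vector] = []
        self.prev_value: List[Vector] = []

    def step(self, q: Vector, k: Vector, v: Vector) -> Vector:
        self.prev_key.append(k)
        self.prev_value.append(v)
        return attend(q, self.prev_key, self.prev_value)


# Toy projected vectors for three decoding steps (projections omitted).
qs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

cache = IncrementalSelfAttention()
incremental = [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]

# Reference: recompute each step from scratch over all tokens up to that step.
full = [attend(qs[t], ks[: t + 1], vs[: t + 1]) for t in range(3)]
```

Both computations give the same vectors, but the cached version only processes one new query, key and value per time step instead of the whole prefix.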

In [None]:
...

        y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
...
        return x, attn, None