# Encoder Layer

According to the paper, the transformer encoder consists of N identical encoder layers. In this notebook, we build the encoder layer, stack several of them into the full encoder, and write basic tests for both.

Below is a diagram of the encoder layer.

![encoder](./photos/Screenshot%20from%202022-09-07%2023-58-57.png)

In [3]:
import torch
from torch import nn
from vit.multiheaded_attentions import MultiHeadAttention

# Feed Forward Network (FFN) Implementation
Let's first implement the position-wise feed-forward network (FFN). According to the paper, the FFN is applied to each position separately and identically, and consists of two fully-connected (linear) transformations with a ReLU activation in between:

$$\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
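To make "applied to each position separately and identically" concrete, here is a small NumPy sketch of the formula. The toy sizes are assumptions chosen for illustration (the paper uses $d_{model} = 512$ and an inner dimension $d_{ff} = 2048$), and NumPy is used only to show the math, not as part of the notebook's module. Because the FFN acts on the last axis, running it over a whole sequence should give the same result as running it on each position in turn.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumed for illustration only).
seq_len, d_model, d_ff = 4, 8, 32

# Parameters of the two linear transformations, shared by every position.
W1 = rng.standard_normal((d_model, d_ff))
b1 = rng.standard_normal(d_ff)
W2 = rng.standard_normal((d_ff, d_model))
b2 = rng.standard_normal(d_model)

def ffn(x):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied along the last axis."""
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

x = rng.standard_normal((seq_len, d_model))  # one sequence: [seq_len, d_model]

# Applying the FFN to the whole sequence at once...
all_at_once = ffn(x)
# ...matches applying it to each position on its own.
one_by_one = np.stack([ffn(x[t]) for t in range(seq_len)])

print(all_at_once.shape)                     # (4, 8)
print(np.allclose(all_at_once, one_by_one))  # True
```

This is why a plain `nn.Linear` layer is enough in PyTorch: it also operates on the last dimension and broadcasts over any leading batch and sequence dimensions.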

In [4]:
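class FeedForwardNetword(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        self.d_model = d_model
        # The paper uses an inner dimension d_ff = 2048 for d_model = 512;
        # here d_ff defaults to d_model unless given explicitly.
        self.d_ff = d_ff if d_ff is not None else d_model
        self.l1 = nn.Linear(d_model, self.d_ff)  # x W_1 + b_1
        self.l2 = nn.Linear(self.d_ff, d_model)  # (...) W_2 + b_2

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]; the same weights are applied at every position
        return self.l2(nn.functional.relu(self.l1(x)))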

Next, we write a simple unit test for the FFN. It verifies that gradients flow back to every parameter of the FFN.

In [25]:
def ffn_unit_test():
    from vit.encoder import FeedForwardNetword
    ffn = FeedForwardNetword(512).to('cpu') #[d_model]
    dummy_inputs = torch.randn((2, 512)).to('cpu') #[batch_size, d_model]
    y = ffn(dummy_inputs)
    loss = y.mean()
    loss.backward()
    for name, param in ffn.named_parameters():
        assert param.grad is not None

In [26]:
ffn_unit_test()

According to the paper, a residual connection is employed around each sub-layer, followed by a layer normalization. That is, the output of each sub-layer is

$$\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$$
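The residual-plus-normalization wrapper can be sketched in a few lines of NumPy. This is an illustration under assumptions: the learnable gain and bias of layer normalization are omitted (equivalent to gain 1 and bias 0, which is how `nn.LayerNorm` is initialized), `eps = 1e-5` mirrors the `nn.LayerNorm` default, and the linear "sub-layer" is a hypothetical stand-in for attention or the FFN.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each position (last axis) to zero mean and unit variance.
    Gain/bias omitted (assumed 1 and 0)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in nn.LayerNorm
    return (x - mean) / np.sqrt(var + eps)

def residual_block(x, sublayer):
    """Post-norm residual wrapper: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer(x))

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))  # [batch_size, seq_len, d_model]

# Hypothetical stand-in sub-layer: any shape-preserving function works here.
W = rng.standard_normal((8, 8))
y = residual_block(x, lambda h: h @ W)

print(y.shape)                                       # (2, 5, 8)
print(np.allclose(y.mean(axis=-1), 0.0))             # True
print(np.allclose(y.std(axis=-1), 1.0, atol=1e-3))   # True
```

Note that the sub-layer must preserve the shape `[..., d_model]` for the addition `x + Sublayer(x)` to be valid, which is why every sub-layer in the encoder maps back to `d_model`.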


In [17]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_head, d_k):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_head, d_k)
        self.ffn = FeedForwardNetword(d_model)
        # Each sub-layer gets its own LayerNorm (separate learnable gain/bias)
        self.attention_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        # Sub-layer 1: multi-head self-attention, residual connection, LayerNorm
        outputs = self.attention_norm(x + self.self_attention(x))
        # Sub-layer 2: the FFN consumes the attention sub-layer's output, not the raw input
        outputs = self.ffn_norm(outputs + self.ffn(outputs))
        return outputs


In [5]:
encoder_layer = EncoderLayer(3,5,7)

In [6]:
dummy_input = torch.randn(2, 11, 3) #[batch_size, seq_len, d_model]

In [7]:
outputs = encoder_layer(dummy_input)

In [8]:
assert outputs.shape == (2, 11, 3)

In [22]:
def encoder_layer_unit_test():
    from vit.encoder import EncoderLayer
    encoder_layer = EncoderLayer(512, 8, 64).to('cpu') #[d_model, num_head, d_k]
    dummy_inputs = torch.randn((2, 128, 512)).to('cpu') #[batch_size, seq_len, d_model]
    y = encoder_layer(dummy_inputs)
    loss = y.mean()
    loss.backward()
    for name, param in encoder_layer.named_parameters():
        assert param.grad is not None

In [23]:
encoder_layer_unit_test()

# Encoder Stack

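Now let's stack encoder layers together to build the full encoder. According to the paper, the encoder is composed of a stack of N = 6 identical layers. "Identical" refers to the architecture: each layer has its own parameters, and the output of one layer is the input to the next.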
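Since the stacked layers share an architecture but not weights, the parameter count grows linearly with the number of layers. The sketch below counts parameters for the paper's base configuration ($d_{model} = 512$, $d_{ff} = 2048$) under stated assumptions: attention uses four $d_{model} \times d_{model}$ projections (queries, keys, values, output), each with a bias, and each of the two LayerNorms has a gain and a bias. The notebook's `FeedForwardNetword` defaults to $d_{ff} = d_{model}$ and the layout of `MultiHeadAttention` is not shown here, so the notebook's own counts may differ.

```python
def linear_params(d_in, d_out):
    """Weights plus bias of one fully-connected layer."""
    return d_in * d_out + d_out

def encoder_layer_params(d_model, d_ff):
    # Assumption: Q, K, V and output projections are each d_model -> d_model with a bias.
    attention = 4 * linear_params(d_model, d_model)
    # Position-wise FFN: d_model -> d_ff -> d_model.
    ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    # Two LayerNorms, each with a gain and a bias of size d_model.
    layer_norms = 2 * 2 * d_model
    return attention + ffn + layer_norms

def encoder_params(num_layer, d_model, d_ff):
    # "Identical" layers share an architecture, not weights,
    # so the total is num_layer times the per-layer count.
    return num_layer * encoder_layer_params(d_model, d_ff)

print(encoder_layer_params(512, 2048))  # 3152384
print(encoder_params(6, 512, 2048))     # 18914304
```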

In [9]:
class EncoderStack(nn.Module):
    def __init__(self, num_layer, d_model, num_head, d_k):
        super().__init__()
        self.num_layer = num_layer
        self.d_model = d_model
        self.num_head = num_head
        self.d_k = d_k
        # num_layer independent layers: same architecture, separate weights
        self.layers = nn.ModuleList([EncoderLayer(d_model, num_head, d_k) for _ in range(num_layer)])

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]; each layer's output feeds the next layer
        for layer in self.layers:
            x = layer(x)
        return x

In [10]:
encoder = EncoderStack(6, 3,5,7)

In [11]:
encoder_outputs = encoder(dummy_input)

In [12]:
encoder_outputs.shape

torch.Size([2, 11, 3])

In [27]:
def encoder_unit_test():
    from vit.encoder import EncoderStack
    encoder = EncoderStack(6, 512, 8, 64).to('cpu') #[num_layer, d_model, num_head, d_k]
    dummy_inputs = torch.randn((2, 128, 512)).to('cpu') #[batch_size, seq_len, d_model]
    y = encoder(dummy_inputs)
    assert y.shape == (2, 128, 512)
    loss = y.mean()
    loss.backward()
    for name, param in encoder.named_parameters():
        assert param.grad is not None, f"no gradient for {name}"

In [None]:
encoder_unit_test()