# Position-wise Feed Forward layer in Attention is All You Need (AAYN)

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<img src="../../assets/poswise_feedforward.png" width="700" height="200">

The position-wise feed-forward network is used in both the encoder block and the decoder block. It is called "position-wise" because the same two-layer network is applied to every position of the sequence separately and identically. The encoder block contains two main sub-layers: (1) multi-head self-attention and (2) a position-wise feed-forward network. The decoder block contains three main sub-layers: (1) masked multi-head self-attention, (2) multi-head cross-attention, and (3) a position-wise feed-forward network.
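
For reference, the feed-forward sub-layer can be written as a single equation. This is the form given in the AAYN paper; the symbols $W_1, b_1, W_2, b_2$ denote the weights and biases of the two linear layers, with $W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$ and $W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$:

$$
\mathrm{FFN}(x) = \max(0,\; xW_1 + b_1)\,W_2 + b_2
$$

Note that `nn.Linear` stores its weight with shape `[out_features, in_features]` and computes $xW^\top + b$, so the PyTorch weight tensors are the transposes of $W_1$ and $W_2$ as written here.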

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model (int): internal dimension of the model, i.e. the embedding
            dimension (also known as 'hidden_size')
        d_ff (int): inner dimension of the feed-forward network
            (usually larger than d_model)
    """
    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the feed-forward network.

        Args:
            x (torch.Tensor): input tensor of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: output tensor of shape [batch_size, seq_len, d_model]
        """
        x = self.W1(x)      # [batch_size, seq_len, d_ff]
        x = self.relu(x)    # add non-linearity
        x = self.W2(x)      # [batch_size, seq_len, d_model]
        return x
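
The "position-wise" property can be checked numerically. The sketch below is a minimal, self-contained check and not part of the original notebook: it uses an `nn.Sequential` stand-in with the same Linear -> ReLU -> Linear structure as `PositionwiseFeedForward` (an assumption made so the snippet runs on its own), and the names `ffn`, `full`, `per_position`, and `perm` are illustrative. It compares the output for the whole sequence with the output obtained by feeding one position at a time, and also checks that shuffling the positions of the input simply shuffles the output in the same way.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, d_model, d_ff = 2, 5, 12, 24

# Stand-in with the same structure as PositionwiseFeedForward (assumption:
# used here only so the snippet is self-contained).
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(batch_size, seq_len, d_model)

with torch.no_grad():
    # Apply the network to the whole sequence at once.
    full = ffn(x)                                   # [batch_size, seq_len, d_model]

    # Apply the network to each position separately, then re-stack.
    per_position = torch.stack(
        [ffn(x[:, t, :]) for t in range(seq_len)], dim=1
    )                                               # [batch_size, seq_len, d_model]

    # Shuffle the positions of the input and run the network again.
    perm = torch.randperm(seq_len)
    permuted = ffn(x[:, perm, :])

# Both comparisons use a small tolerance for floating-point differences.
same_per_position = torch.allclose(full, per_position, atol=1e-5)
permutation_equivariant = torch.allclose(full[:, perm, :], permuted, atol=1e-5)

print(same_per_position, permutation_equivariant)
```

Both flags are expected to be `True` up to floating-point tolerance, since no information is exchanged between positions inside this sub-layer; mixing across positions is left to the attention sub-layers.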


#### Toy Example

In [3]:
batch_size = 2
d_model = 12    # model dimension (also called hidden size)
seq_len = 5
d_ff = 24

# generate random dummy input
x = torch.randn(batch_size, seq_len, d_model)

# instantiate the network
pffn = PositionwiseFeedForward(d_model, d_ff)
out = pffn(x) # forward pass

print(f'input x: {x.size()}, and feed forward output: {out.size()}')

input x: torch.Size([2, 5, 12]), and feed forward output: torch.Size([2, 5, 12])
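
To get a feel for the size of this sub-layer, the parameters of the two linear layers can be counted directly. The snippet below is a small sketch using the same toy sizes as above; `count_params` is a hypothetical helper introduced only for illustration, and the two `nn.Linear` layers are created standalone rather than taken from `pffn` so that the snippet runs on its own.

```python
import torch.nn as nn

d_model, d_ff = 12, 24  # same toy sizes as in the example above

# Standalone copies of the two linear layers (illustrative only).
W1 = nn.Linear(d_model, d_ff)
W2 = nn.Linear(d_ff, d_model)


def count_params(module: nn.Module) -> int:
    """Hypothetical helper: total number of elements in all parameters."""
    return sum(p.numel() for p in module.parameters())


w1_params = count_params(W1)   # weights 12*24 + biases 24 = 312
w2_params = count_params(W2)   # weights 24*12 + biases 12 = 300
total_params = w1_params + w2_params

# Closed form: two weight matrices plus the two bias vectors.
formula = 2 * d_model * d_ff + d_ff + d_model

print(w1_params, w2_params, total_params)  # 312 300 612
```

The count does not depend on `seq_len` or `batch_size`, because the same weights are reused at every position.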
