# Decoder in Attention is All You Need (AAYN)

### Step 1: Implementation of the Decoder Block


### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<br>

**Decoder Block**
<p align="center">
  <img src="../../assets/decoder_block.png" width="700" height="350"> 
</p>


<br>

**Cross Attention**
<p align="center">
  <img src="../../assets/cross_attention.png" width="650" height="300"> 
</p>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# for decoder block we need 'PositionwiseFeedForward' module (positionwise_feed_forward.ipynb)

class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear(d_model -> d_ff) -> ReLU -> Linear(d_ff -> d_model).
    The same two linear layers are applied independently at every sequence position.

    Args:
        d_model (int): internal dimension of the model / embedding dimension
            (also known as 'hidden_size').
        d_ff (int): inner dimension of the feed-forward network (usually larger than d_model).
    """
    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()

        # keep the sizes around for inspection
        self.d_model, self.d_ff = d_model, d_ff

        # expand -> non-linearity -> project back
        # (sub-modules registered in the order W1, W2, relu)
        self.W1 = nn.Linear(self.d_model, self.d_ff)
        self.W2 = nn.Linear(self.d_ff, self.d_model)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the feed-forward network over every position.

        Args:
            x (torch.Tensor): tensor of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: tensor of shape [batch_size, seq_len, d_model]
        """
        hidden = self.relu(self.W1(x))   # [batch_size, seq_len, d_ff]
        return self.W2(hidden)           # [batch_size, seq_len, d_model]

In [3]:
class DecoderBlock(nn.Module):
    """
    Transformer decoder block (post-norm, as in "Attention Is All You Need"). It consists of
    1. masked (causal) multi-head self attention + residual connection -> normalization layer
    2. cross attention over the encoder output + residual connection -> normalization layer
    3. position-wise feed forward + residual connection -> normalization layer

    Args:
        d_model (int): internal dimension of the model or dimension of embeddings.
        (also known as 'hidden_size')
        n_head (int): number of attention heads (nn.MultiheadAttention requires d_model % n_head == 0)
        d_ff (int): dimension of feed-forward network
    """
    def __init__(self, d_model: int, n_head: int, d_ff: int):
        super().__init__()

        # self-attention: each target position attends to itself and earlier target positions
        # (the causal mask that enforces this is built in forward)
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)

        # cross attention: queries come from the decoder, keys/values from the encoder output
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)

        # position-wise feed forward network
        self.pos_feed_forward = PositionwiseFeedForward(d_model, d_ff)

        # layer normalization after each residual connection
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def _causal_mask(size: int, device: torch.device) -> torch.Tensor:
        """Boolean [size, size] mask that is True strictly above the diagonal (the future positions)."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor, *, causal: bool = True,
                tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None) -> torch.Tensor:
        """
        forward pass of decoder block

        Args:
            x (torch.Tensor): (target) input embedding of shape [batch_size, target_seq_len, d_model]
            encoder_output (torch.Tensor): Encoder output of shape [batch_size, seq_len, d_model]
            causal (bool): when True (default) and no `tgt_mask` is given, a causal mask stops each
                target position from attending to later target positions. Without it the decoder
                can "see the answer" during teacher-forced training.
            tgt_mask (torch.Tensor, optional): explicit self-attention mask, passed to
                nn.MultiheadAttention as `attn_mask` (bool: True = blocked, or float: added to scores).
                When given, it replaces the default causal mask.
            tgt_key_padding_mask (torch.Tensor, optional): [batch_size, target_seq_len] mask of padded
                target positions (True = ignore) for self-attention.
            memory_key_padding_mask (torch.Tensor, optional): [batch_size, seq_len] mask of padded
                encoder positions (True = ignore) for cross attention.

        Returns:
            torch.Tensor: output tensor of shape [batch_size, target_seq_len, d_model]
        """
        if tgt_mask is None and causal:
            tgt_mask = self._causal_mask(x.size(1), x.device)   # [target_seq_len, target_seq_len]

        # 1. masked self-attention + residual connection that is passed to normalization layer
        attn_output, _ = self.self_attn(query=x, key=x, value=x, attn_mask=tgt_mask,
                                        key_padding_mask=tgt_key_padding_mask,
                                        need_weights=False)    # [batch_size, target_seq_len, d_model]
        x = self.norm1(attn_output + x)

        # 2. cross attention + residual connection that is passed to normalization layer
        cross_att_output, _ = self.cross_attn(query=x, key=encoder_output, value=encoder_output,
                                              key_padding_mask=memory_key_padding_mask,
                                              need_weights=False)  # [batch_size, target_seq_len, d_model]
        x = self.norm2(cross_att_output + x)

        # 3. position-wise feed forward + residual connection that is passed to normalization layer
        ff_output = self.pos_feed_forward(x)    # [batch_size, target_seq_len, d_model]
        x = self.norm3(ff_output + x)

        return x

#### Toy Example: Testing Decoder Block

In [4]:
# toy sizes: seq_len is the encoder sequence length, target_seq_len the decoder sequence length
batch_size, seq_len, target_seq_len = 2, 5, 7

# block hyper-parameters: model dim (hidden size), attention heads, feed-forward inner dim
d_model, n_head, d_ff = 12, 3, 48

# random stand-ins for the decoder input embeddings and the encoder output
decoder_input = torch.randn(batch_size, target_seq_len, d_model)
encoder_output = torch.randn(batch_size, seq_len, d_model)

decoder_block = DecoderBlock(d_model=d_model, n_head=n_head, d_ff=d_ff)
decoded = decoder_block(decoder_input, encoder_output)   # [batch_size, target_seq_len, d_model]

print(f'encoder_output: {encoder_output.size()}, and decoder_input: {decoder_input.size()}, and deocoder block output: {decoded.size()}')

encoder_output: torch.Size([2, 5, 12]), and decoder_input: torch.Size([2, 7, 12]), and deocoder block output: torch.Size([2, 7, 12])


### Step 2: Implementation of Decoder

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<p align="center">
  <img src="../../assets/decoder.png" width="750" height="400">
</p>

In [5]:
# For implementing the decoder, PositionalEncoding is required (positional_encoding.ipynb)

class PositionalEncoding(nn.Module):
    """
    Fixed (non-learned) sinusoidal positional encoding from "Attention Is All You Need":
        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        context_size (int): maximum length of the input sequence (also known as max_length)
        d_model (int): internal dimension of the model or dimension of embeddings.
        (also known as 'hidden_size')
    """
    def __init__(self, context_size: int, d_model: int):
        super().__init__()

        pos = torch.arange(context_size, dtype=torch.float32).unsqueeze(dim=1)  # [context_size, 1]

        # frequency indices i = 0 .. ceil(d_model / 2) - 1. The sin columns (0, 2, 4, ...) need
        # ceil(d_model / 2) frequencies and the cos columns (1, 3, ...) need floor(d_model / 2),
        # so both even and odd d_model are covered (e.g. d_model=5 -> 3 sin + 2 cos columns)
        i = torch.arange((d_model + 1) // 2, dtype=torch.float32)
        div_term = 10000 ** (2 * i / d_model)

        # positional encoding table [context_size, d_model]
        encoding = torch.zeros(context_size, d_model)
        encoding[:, 0::2] = torch.sin(pos / div_term)                      # even columns
        encoding[:, 1::2] = torch.cos(pos / div_term[:d_model // 2])       # odd columns

        # Register as a buffer: part of the module state (saved in state_dict, moved by .to()/.cuda()),
        # but not a learnable parameter, since the vanilla AAYN positional encoding is not trained.
        # The table is stored ONLY as the buffer; a second plain attribute would go stale after .to().
        self.register_buffer('pos_encoding', encoding)

    @property
    def encoding(self) -> torch.Tensor:
        """Alias of the `pos_encoding` buffer (kept for backward compatibility)."""
        return self.pos_encoding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        returns positional encoding for a given input tensor x.
        (input tensor x comes from the token embedding layer in the transformer architecture;
        only its second dimension, the sequence length, is used)

        Args:
            x (torch.Tensor): input tensor [batch_size, seq_len, ...]

        Returns:
            torch.Tensor: positional encoding slice of shape [seq_len, d_model],
            ready to be added (broadcast) to token embeddings.

        Raises:
            ValueError: if seq_len is larger than the context_size the table was built for.
        """
        seq_len = x.size(1)  # number of tokens in the input sequence
        max_len = self.pos_encoding.size(0)
        if seq_len > max_len:
            # slicing would silently return only max_len rows and fail later with a shape error
            raise ValueError(f"sequence length {seq_len} exceeds context_size {max_len}")

        return self.pos_encoding[:seq_len, :]

In [None]:
class Decoder(nn.Module):
    """
    Transformer Decoder that consists of
    1. token embedding
    2. positional encoding (added to the token embeddings)
    3. `n_block` stacked decoder blocks (each also attends to the encoder output)
    4. final linear projection to the target vocabulary (raw logits, no softmax)

    Args:
        output_size (int): size of the vocabulary (target vocab size).
        context_size (int): max sequence length (size of the positional-encoding table)
        d_model (int): embedding dimension
        n_head (int): number of attention heads
        d_ff (int): inner dimension of the position-wise feed forward network
        n_block (int): number of stacked decoder blocks
    """
    def __init__(self, output_size: int, context_size: int, d_model: int, n_head: int, d_ff: int, n_block: int):
        super().__init__()

        # NOTE: AAYN (Sec. 3.4) also multiplies embeddings by sqrt(d_model); this implementation does not.
        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        # stacked decoder blocks
        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, n_head, d_ff) for _ in range(n_block)
        ])

        # final linear projection (to target vocabulary size)
        self.out = nn.Linear(d_model, output_size)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor:
        """
        decoder forward pass

        Args:
            x (torch.Tensor): target token indices of shape [batch_size, target_seq_len].
            Each element is an integer token id
            encoder_output (torch.Tensor): output of the encoder of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: decoder logits of shape [batch_size, target_seq_len, output_size] where
            output_size is the target vocabulary size.
        """
        embedded = self.embedding(x)            # [batch_size, target_seq_len, d_model]
        pos_embedded = self.pos_embedding(x)    # [target_seq_len, d_model] (only x.size(1) is used)
        x = embedded + pos_embedded             # broadcast -> [batch_size, target_seq_len, d_model]

        for block in self.blocks:
            x = block(x, encoder_output)        # [batch_size, target_seq_len, d_model]

        output = self.out(x)    # [batch_size, target_seq_len, output_size]
        return output

#### Toy Example: Testing Decoder Module

In [7]:
# toy sizes: encoder length, decoder (target) length and target vocabulary size
batch_size, seq_len, target_seq_len, vocab_size = 2, 5, 7, 100

# model hyper-parameters: model dim (hidden size), attention heads, feed-forward inner dim, #blocks
d_model, n_head, d_ff, n_block = 12, 3, 48, 5

# dummy target token ids and a random stand-in for the encoder output
target_input = torch.randint(0, vocab_size, (batch_size, target_seq_len))
encoder_output = torch.randn(batch_size, seq_len, d_model)

decoder = Decoder(
    output_size=vocab_size,
    context_size=target_seq_len,   # positional table exactly as long as the target sequence
    d_model=d_model,
    n_head=n_head,
    d_ff=d_ff,
    n_block=n_block,
)
decoded = decoder(target_input, encoder_output)   # [batch_size, target_seq_len, vocab_size]

print(f'encoder_output: {encoder_output.size()}, and decoder_input: {target_input.size()}, and deocoder block output: {decoded.size()}')

encoder_output: torch.Size([2, 5, 12]), and decoder_input: torch.Size([2, 7]), and deocoder block output: torch.Size([2, 7, 100])


In [8]:
# display the decoder's module hierarchy (shown below as the notebook cell output)
decoder

Decoder(
  (embedding): Embedding(100, 12)
  (pos_embedding): PositionalEncoding()
  (blocks): ModuleList(
    (0-4): 5 x DecoderBlock(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
      )
      (cross_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
      )
      (pos_feed_forward): PositionwiseFeedForward(
        (W1): Linear(in_features=12, out_features=48, bias=True)
        (W2): Linear(in_features=48, out_features=12, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
    )
  )
  (out): Linear(in_features=12, out_features=100, bias=True)
)