# Transformer from "Attention Is All You Need" (AAYN)

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<br>

**Transformer Architecture**
<p align="center">
  <img src="../../assets/trasformer.png" width="600" height="350">
</p>


In [1]:
import torch
import torch.nn as nn
import torch.functional as F

#### Helper Modules

In [2]:
# Necessary modules required for Transformer

class PositionalEncoding(nn.Module):
    """
    Fixed (non-learned) sinusoidal positional encoding from "Attention Is All You Need".

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        context_size (int): maximum length of the input sequence (also known as max_length)
        d_model (int): internal dimension of the model or dimension of embeddings
            (also known as 'hidden_size'). Both even and odd values are supported.
    """
    def __init__(self, context_size: int, d_model: int):
        super().__init__()

        pos = torch.arange(context_size, dtype=torch.float32).unsqueeze(dim=1)  # [context_size, 1]

        # frequency indices i = 0 .. ceil(d_model / 2) - 1.
        # The sine (even) columns need ceil(d_model / 2) frequencies and the cosine (odd)
        # columns need floor(d_model / 2); slicing div_term below covers both even and odd
        # d_model (e.g. d_model=5 -> 2i = (0, 2, 4) for sin, (0, 2) for cos).
        i = torch.arange((d_model + 1) // 2, dtype=torch.float32)
        div_term = 10000 ** (2 * i / d_model)  # [ceil(d_model / 2)]

        # build the table in a local tensor: keeping a second plain attribute that aliases
        # the buffer would go stale (stay on the old device) after .to()/.cuda()/.half()
        encoding = torch.zeros(context_size, d_model)
        encoding[:, 0::2] = torch.sin(pos / div_term)                   # even columns
        encoding[:, 1::2] = torch.cos(pos / div_term[:d_model // 2])    # odd columns

        # Register the table as part of the module state, but not as a learnable parameter
        # (positional encodings in the vanilla "Attention Is All You Need" are not trained).
        # As a buffer it moves with the model (.to(), .cuda()) and is saved in state_dict.
        self.register_buffer('pos_encoding', encoding)

    @property
    def encoding(self) -> torch.Tensor:
        """Backward-compatible alias of the `pos_encoding` buffer (always on the module's device)."""
        return self.pos_encoding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the positional encoding for the sequence length of x.
        (x is the input coming into the token embedding layer of the transformer; only its
        second dimension is used.)

        Args:
            x (torch.Tensor): input tensor whose dim 1 is the sequence length,
                e.g. token ids [batch_size, seq_len] or embeddings [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: positional encoding slice of shape [seq_len, d_model],
            ready to be added (broadcast over the batch) to token embeddings.

        Raises:
            ValueError: if seq_len is larger than the context_size the table was built for.
        """
        seq_len = x.size(1)  # number of tokens in the input sequence
        max_len = self.pos_encoding.size(0)
        if seq_len > max_len:
            # slicing would silently return only max_len rows and fail later with an
            # obscure broadcasting error; report the real cause here instead
            raise ValueError(
                f"sequence length {seq_len} exceeds context_size {max_len} of the positional encoding"
            )

        # read from the registered buffer so the result is on the module's device/dtype
        return self.pos_encoding[:seq_len, :]


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network: FFN(x) = W2(ReLU(W1(x))).

    The same two linear layers act on the last dimension, so every position of the
    sequence is transformed independently.

    Args:
        d_model (int): internal dimension of the model or dimension of embeddings
            (also known as 'hidden_size')
        d_ff (int): inner dimension of the feed-forward network (usually larger than d_model)
    """
    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff

        # submodule names W1 / W2 / relu determine the state_dict keys
        self.W1 = nn.Linear(d_model, d_ff)   # expand:       d_model -> d_ff
        self.W2 = nn.Linear(d_ff, d_model)   # project back: d_ff    -> d_model
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the feed-forward network to every position.

        Args:
            x (torch.Tensor): input of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: output of shape [batch_size, seq_len, d_model]
        """
        hidden = self.relu(self.W1(x))   # [batch_size, seq_len, d_ff], non-linearity applied
        return self.W2(hidden)           # [batch_size, seq_len, d_model]






#### Encoder Module

In [3]:
class EncoderBlock(nn.Module):
    """
    Single Transformer encoder block (post-norm: normalization after each residual add):
    1. multi-head self-attention + residual connection -> layer normalization
    2. position-wise feed-forward + residual connection -> layer normalization

    Args:
        d_model (int): internal dimension of the model or dimension of embeddings
            (also known as 'hidden_size'). nn.MultiheadAttention requires it to be
            divisible by n_head.
        n_head (int): number of attention heads
        d_ff (int): dimension of the position-wise feed-forward network
    """
    def __init__(self, d_model: int, n_head: int, d_ff: int) -> None:
        super().__init__()

        self.d_model = d_model
        self.n_head = n_head
        self.d_ff = d_ff

        # PyTorch's built-in multi-head attention; batch_first=True -> [batch, seq, feature]
        # see: https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)

        self.pos_feed_forward = PositionwiseFeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the encoder block.

        Args:
            x (torch.Tensor): input embedding of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: encoded tensor of shape [batch_size, seq_len, d_model]
        """
        # 1. self-attention + residual connection, then normalization.
        # need_weights=False: the attention map is discarded anyway, so skip materialising
        # the [batch_size, seq_len, seq_len] weight tensor (this can also let PyTorch use
        # its fused scaled-dot-product-attention kernel).
        attn_output, _ = self.self_attn(query=x, key=x, value=x, need_weights=False)  # [batch_size, seq_len, d_model]
        x = self.norm1(x + attn_output)  # [batch_size, seq_len, d_model]

        # 2. position-wise feed-forward + residual connection, then normalization
        ff_output = self.pos_feed_forward(x)  # [batch_size, seq_len, d_model]
        x = self.norm2(x + ff_output)         # [batch_size, seq_len, d_model]

        return x


class Encoder(nn.Module):
    """
    Transformer encoder:
    1. token embedding
    2. sinusoidal positional encoding (added to the token embeddings)
    3. n_block stacked encoder blocks applied in sequence

    Args:
        input_size (int): vocabulary size of the source tokens
        context_size (int): maximum sequence length (max tokens per input)
        d_model (int): embedding dimension
        n_head (int): number of attention heads
        d_ff (int): dimension of the position-wise feed-forward network
        n_block (int): number of stacked encoder blocks
    """
    def __init__(self, input_size: int, context_size: int, d_model: int, n_head: int, d_ff: int, n_block: int):
        super().__init__()

        self.embedding = nn.Embedding(input_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        # Wrapping the blocks in nn.ModuleList (instead of a plain Python list) registers
        # them as submodules: their weights appear in model.parameters() (so the optimizer
        # updates them), in state_dict() (saving/loading) and follow .to(device).
        encoder_blocks = [EncoderBlock(d_model, n_head, d_ff) for _ in range(n_block)]
        self.blocks = nn.ModuleList(encoder_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of source token ids.

        Args:
            x (torch.Tensor): integer token ids of shape [batch_size, seq_len]

        Returns:
            torch.Tensor: encoder output of shape [batch_size, seq_len, d_model]
        """
        # [batch_size, seq_len, d_model] + [seq_len, d_model] -> broadcast over the batch
        hidden = self.embedding(x) + self.pos_embedding(x)

        for encoder_block in self.blocks:
            hidden = encoder_block(hidden)  # shape preserved: [batch_size, seq_len, d_model]

        return hidden



#### Decoder Module

In [4]:
class DecoderBlock(nn.Module):
    """
    Transformer decoder block (post-norm: normalization after each residual add):
    1. causally-masked multi-head self-attention + residual connection -> layer normalization
    2. cross-attention over the encoder output + residual connection -> layer normalization
    3. position-wise feed-forward + residual connection -> layer normalization

    Args:
        d_model (int): internal dimension of the model or dimension of embeddings
            (also known as 'hidden_size'). nn.MultiheadAttention requires it to be
            divisible by n_head.
        n_head (int): number of attention heads
        d_ff (int): dimension of the position-wise feed-forward network
    """
    def __init__(self, d_model: int, n_head: int, d_ff: int):
        super().__init__()

        # self-attention: each target position attends only to itself and earlier positions
        # (enforced by the causal mask built in forward)
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)

        # cross-attention: decoder queries attend to the encoder output
        self.cross_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_head, batch_first=True)

        # position-wise feed-forward network
        self.pos_feed_forward = PositionwiseFeedForward(d_model, d_ff)

        # layer normalization after each residual connection
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the decoder block.

        Args:
            x (torch.Tensor): (target) input embedding of shape [batch_size, target_seq_len, d_model]
            encoder_output (torch.Tensor): encoder output of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: output tensor of shape [batch_size, target_seq_len, d_model]
        """
        target_len = x.size(1)

        # Boolean causal mask [target_seq_len, target_seq_len]: True strictly above the
        # diagonal marks "future" positions that nn.MultiheadAttention must NOT attend to.
        # Without it the decoder sees the tokens it is supposed to predict.
        causal_mask = torch.triu(
            torch.ones(target_len, target_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )

        # 1. masked self-attention + residual connection, then normalization
        # (need_weights=False: the attention map is discarded, so don't compute it)
        attn_output, _ = self.self_attn(
            query=x, key=x, value=x, attn_mask=causal_mask, need_weights=False
        )  # [batch_size, target_seq_len, d_model]
        x = self.norm1(attn_output + x)

        # 2. cross-attention + residual connection, then normalization
        cross_attn_output, _ = self.cross_attn(
            query=x, key=encoder_output, value=encoder_output, need_weights=False
        )  # [batch_size, target_seq_len, d_model]
        x = self.norm2(cross_attn_output + x)

        # 3. position-wise feed-forward + residual connection, then normalization
        ff_output = self.pos_feed_forward(x)  # [batch_size, target_seq_len, d_model]
        x = self.norm3(ff_output + x)

        return x


class Decoder(nn.Module):
    """
    Transformer decoder:
    1. token embedding
    2. sinusoidal positional encoding (added to the token embeddings)
    3. n_block stacked decoder blocks, each attending to the encoder output
    4. final linear projection to the output (target) vocabulary

    Args:
        output_size (int): size of the target vocabulary
        context_size (int): maximum sequence length (for the positional encoding)
        d_model (int): embedding dimension
        n_head (int): number of attention heads
        d_ff (int): dimension of the position-wise feed-forward network
        n_block (int): number of stacked decoder blocks
    """
    def __init__(self, output_size: int, context_size: int, d_model: int, n_head: int, d_ff: int, n_block: int):
        super().__init__()

        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        # stacked decoder blocks (nn.ModuleList registers them as submodules)
        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, n_head, d_ff) for _ in range(n_block)
        ])

        # final linear projection (to target vocabulary size) producing logits
        self.out = nn.Linear(d_model, output_size)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor:
        """
        Decoder forward pass.

        Args:
            x (torch.Tensor): target token ids of shape [batch_size, target_seq_len];
                each element is an integer token id
            encoder_output (torch.Tensor): output of the encoder of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: logits of shape [batch_size, target_seq_len, output_size], where
            output_size is the target vocabulary size (no softmax applied).
        """
        embedded = self.embedding(x)            # [batch_size, target_seq_len, d_model]
        pos_embedded = self.pos_embedding(x)    # [target_seq_len, d_model]
        x = embedded + pos_embedded             # broadcast -> [batch_size, target_seq_len, d_model]

        for block in self.blocks:
            x = block(x, encoder_output)        # [batch_size, target_seq_len, d_model]

        output = self.out(x)    # [batch_size, target_seq_len, output_size]

        return output

### Transformer

In [5]:
class VanillaTransformer(nn.Module):
    """
    Vanilla encoder-decoder Transformer based on the architecture from 'Attention Is All You Need'.

    Args:
        input_size (int): vocabulary size of the source language (input to the transformer)
        output_size (int): vocabulary size of the target language (output of the transformer)
        context_size (int): maximum sequence length (max tokens per input); shared by the
            positional encodings of the encoder and the decoder
        d_model (int): embedding dimension
        n_head (int): number of attention heads
        d_ff (int): dimension of the position-wise feed-forward network
        n_block (int): number of stacked blocks in the encoder AND (separately) in the decoder
    """
    def __init__(self, input_size: int, output_size: int, context_size: int, d_model: int, n_head: int, d_ff: int, n_block: int) -> None:
        super().__init__()

        # Encoder: processes the source (input) sequence
        self.encoder = Encoder(
            input_size=input_size,
            context_size=context_size,
            d_model=d_model,
            n_head=n_head,
            d_ff=d_ff,
            n_block=n_block
        )

        # Decoder: generates the target (output) sequence conditioned on the encoder output
        self.decoder = Decoder(
            output_size=output_size,
            context_size=context_size,
            d_model=d_model,
            n_head=n_head,
            d_ff=d_ff,
            n_block=n_block
        )

    def forward(self, x_encoder: torch.Tensor, x_decoder: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the vanilla transformer.

        Args:
            x_encoder (torch.Tensor): source token ids of shape [batch_size, seq_len]
            x_decoder (torch.Tensor): target token ids of shape [batch_size, target_seq_len]

        Returns:
            torch.Tensor: decoder output logits of shape [batch_size, target_seq_len, output_size].
            These logits can be passed to 'nn.CrossEntropyLoss' during training, or through a
            'softmax' to obtain a probability for each token in the vocabulary.
        """
        encoder_output = self.encoder(x_encoder)    # [batch_size, seq_len, d_model]
        decoder_output = self.decoder(x_decoder, encoder_output)    # [batch_size, target_seq_len, output_size]

        return decoder_output

#### Toy Example: Testing Transformer

In [6]:
# ==========================
# Hyperparameters
# ==========================
torch.manual_seed(0)     # reproducible random inputs and weight initialization

batch_size = 2
vocab_size = 100         # shared by source and target vocabularies in this toy example
seq_len = 10             # length of encoder (source) sequence
target_seq_len = 12      # length of decoder (target) sequence

d_model = 12            # model dim / hidden size / embedding dim (divisible by n_head)
n_head = 3              # number of attention heads
d_ff = 48               # inner dimension of the position-wise feed-forward layer
n_block = 5             # number of stacked blocks in the encoder and in the decoder

# the positional-encoding table (shared size for encoder & decoder) must cover the longer sequence
context_size = max(seq_len, target_seq_len)

# ==========================
# Inputs
# ==========================
# random token ids (nn.Embedding only accepts integer indices)
encoder_x = torch.randint(0, vocab_size, (batch_size, seq_len))
decoder_x = torch.randint(0, vocab_size, (batch_size, target_seq_len))

# ==========================
# Transformer & Forward Pass
# ==========================
transformer = VanillaTransformer(
    input_size=vocab_size,
    output_size=vocab_size,
    context_size=context_size,
    d_model=d_model,
    n_head=n_head,
    d_ff=d_ff,
    n_block=n_block
)
x_transformed = transformer(encoder_x, decoder_x)


# ==========================
# Assertions
# ==========================
# Transformer output should be [batch_size, target_seq_len, output_size]
expected_output_shape = (batch_size, target_seq_len, vocab_size)
assert x_transformed.shape == expected_output_shape, f"Expected output shape {expected_output_shape}, got {x_transformed.shape}"

print("Transformer forward pass successful!")
print(f"encoder_x: {encoder_x.shape}, decoder_x: {decoder_x.shape}, output: {x_transformed.shape}")

Transformer forward pass successful!
encoder_x: torch.Size([2, 10]), decoder_x: torch.Size([2, 12]), output: torch.Size([2, 12, 100])


In [7]:
# Display the module hierarchy (a notebook renders the repr of a cell's last expression).
transformer

VanillaTransformer(
  (encoder): Encoder(
    (embedding): Embedding(100, 12)
    (pos_embedding): PositionalEncoding()
    (blocks): ModuleList(
      (0-4): 5 x EncoderBlock(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (pos_feed_forward): PositionwiseFeedForward(
          (W1): Linear(in_features=12, out_features=48, bias=True)
          (W2): Linear(in_features=48, out_features=12, bias=True)
          (relu): ReLU()
        )
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(100, 12)
    (pos_embedding): PositionalEncoding()
    (blocks): ModuleList(
      (0-4): 5 x DecoderBlock(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True