# Transformer Block Implementation

To build a simple Transformer block and pass the input $X$ through it, let's extend the previous example into a full Transformer block. The block consists of the following components:

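* **Self-Attention:** Computes attention scores between every pair of positions, so each token can attend to all parts of the sequence. Here it is multi-head attention (`nn.MultiheadAttention`).
* **Feed-Forward Network:** A two-layer fully connected network with a ReLU activation, applied independently at each position.
* **Layer Normalization:** Applied after each residual addition, once for the attention sublayer and once for the feed-forward sublayer, to stabilize training.
* **Residual Connections:** Add each sublayer's input to its (dropout-regularized) output, around both the attention and feed-forward layers.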
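The attention scores in the first bullet are scaled dot products. As a rough illustration of what each head of `nn.MultiheadAttention` computes, here is a single-head sketch of $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$. The class name `SingleHeadSelfAttention` is hypothetical; the sketch omits the output projection that `nn.MultiheadAttention` applies after combining its heads, and, unlike the block below, it uses a batch-first layout `[batch_size, seq_len, embed_size]` to keep the matrix products simple:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleHeadSelfAttention(nn.Module):
    """Hypothetical single-head self-attention, written out by hand for illustration."""

    def __init__(self, embed_size):
        super().__init__()
        # Separate learned projections for queries, keys, and values
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x: [batch_size, seq_len, embed_size] (batch-first in this sketch)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        d_k = q.size(-1)
        # Scaled dot-product scores: [batch_size, seq_len, seq_len]
        scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
        # Softmax over the key dimension, so each row sums to 1
        weights = F.softmax(scores, dim=-1)
        # Weighted sum of values: [batch_size, seq_len, embed_size]
        return weights @ v, weights


torch.manual_seed(0)
attn = SingleHeadSelfAttention(embed_size=4)
x = torch.rand(3, 5, 4)  # [batch_size, seq_len, embed_size]
out, weights = attn(x)
print(out.shape)      # torch.Size([3, 5, 4])
print(weights.shape)  # torch.Size([3, 5, 5])
```

Multi-head attention runs several such heads on smaller projections (here `embed_size // num_heads` dimensions each) and concatenates their outputs.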

Below is the implementation of a Transformer block in PyTorch:

```python
import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, num_heads, ff_hidden_size, dropout=0.1):
        super().__init__()

        # Multi-head self-attention; embed_size must be divisible by num_heads.
        # batch_first defaults to False, so inputs are [seq_len, batch_size, embed_size].
        self.attention = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout)

        # Position-wise feed-forward network: expand to ff_hidden_size, then project back
        self.ffn = nn.Sequential(
            nn.Linear(embed_size, ff_hidden_size),
            nn.ReLU(),
            nn.Linear(ff_hidden_size, embed_size),
        )

        # One LayerNorm per sublayer (post-norm: applied after each residual addition)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Dropout applied to each sublayer's output before the residual addition
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x is expected to have shape [seq_len, batch_size, embed_size]

        # Self-attention sublayer (Q = K = V = x); the attention weights are discarded
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_out))  # Add & Norm

        # Feed-forward sublayer with residual connection
        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ff_out))  # Add & Norm

        return x
```

```python
# Example sequence length, batch size, and embedding size
seq_len = 5     # Sequence length (e.g., number of tokens)
batch_size = 3  # Number of samples in the batch
embed_size = 4  # Dimensionality of each token embedding

# Random input tensor X with shape [seq_len, batch_size, embed_size]
X = torch.rand(seq_len, batch_size, embed_size)
```
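The input is laid out as `[seq_len, batch_size, embed_size]` because `nn.MultiheadAttention` defaults to `batch_first=False`. If your data is batch-first instead, a sketch like the following (assuming a PyTorch version that supports the `batch_first` argument) suggests that passing `batch_first=True` with identical weights gives the same result up to swapping the first two dimensions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_size, num_heads = 4, 2

# Two attention modules that differ only in their expected input layout
attn_seq_first = nn.MultiheadAttention(embed_size, num_heads)
attn_batch_first = nn.MultiheadAttention(embed_size, num_heads, batch_first=True)
attn_batch_first.load_state_dict(attn_seq_first.state_dict())  # copy the weights

X = torch.rand(5, 3, embed_size)  # [seq_len, batch_size, embed_size]
X_bf = X.transpose(0, 1)          # [batch_size, seq_len, embed_size]

out_seq_first, _ = attn_seq_first(X, X, X)
out_batch_first, _ = attn_batch_first(X_bf, X_bf, X_bf)

# Same values, only with the seq and batch dimensions swapped
print(torch.allclose(out_seq_first.transpose(0, 1), out_batch_first, atol=1e-6))  # True
```

Both modules use the default attention dropout of 0.0 here, so the comparison is deterministic.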


```python
# Initialize the Transformer block (embed_size = 4 is divisible by num_heads = 2)
num_heads = 2
ff_hidden_size = 8  # Hidden size of the feed-forward network
transformer_block = TransformerBlock(embed_size, num_heads, ff_hidden_size)

# Pass the input through the Transformer block
output = transformer_block(X)

print(output.shape)  # The block preserves the input shape: [seq_len, batch_size, embed_size]
```


```text
torch.Size([5, 3, 4])
```
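`forward` discards the second value returned by `nn.MultiheadAttention`, which holds the attention weights. A sketch for inspecting them, using a standalone attention module with the same sizes as above (an illustration, not part of the original block): by default the weights are averaged over heads and have shape `[batch_size, seq_len, seq_len]`, with each row a softmax distribution over the keys. The module is put in `eval()` mode so that attention dropout does not alter the returned weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, embed_size, num_heads = 5, 3, 4, 2

attention = nn.MultiheadAttention(embed_size, num_heads, dropout=0.1)
attention.eval()  # disable dropout so the weights are plain softmax outputs

X = torch.rand(seq_len, batch_size, embed_size)
with torch.no_grad():
    # need_weights=True by default, so the weights are returned
    attn_out, attn_weights = attention(X, X, X)

print(attn_out.shape)            # torch.Size([5, 3, 4])
print(attn_weights.shape)        # torch.Size([3, 5, 5]), averaged over heads
print(attn_weights.sum(dim=-1))  # every entry is (approximately) 1.0
```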
