In [8]:
import torch.nn as nn
import torch

In [1]:
# Hyperparameters for the 124M-parameter GPT configuration.
# TransformerBlock reads emb_dim, context_length, n_heads, drop_rate and qkv_bias.
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # size of the token vocabulary
    context_length=1024,   # context length passed to the attention layer
    emb_dim=768,           # width of each token embedding
    n_heads=12,            # number of attention heads
    n_layers=12,           # number of transformer layers
    drop_rate=0.1,         # dropout probability
    qkv_bias=False,        # query/key/value bias flag
)

In [3]:
import sys

# Make the parent directory importable (presumably where the `modular`
# package used in the next cell lives). A relative sys.path entry is resolved
# against the kernel's current working directory at import time, not against
# this notebook file's location.
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

In [10]:
from modular.AttentionMechanisms.MultiHeadAttention import MultiHeadAttention
from modular.GPT_architecture.FeedForwardBlock import FeedForward
from modular.GPT_architecture.LayerNormalization import LayerNorm

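Let's code the transformer block. It consists of two sub-layers, and each one is wrapped in a shortcut (residual) connection:
* attention sub-layer: layer norm → multi-head attention → dropout, then add the sub-layer's input back
* feed-forward sub-layer: layer norm → feed-forward network → dropout, then add the sub-layer's input back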

In [11]:
class TransformerBlock(nn.Module):
    """GPT-style transformer block: two pre-norm sub-layers, each with a shortcut.

    Sub-layer 1: LayerNorm -> multi-head attention -> dropout, then add the input back.
    Sub-layer 2: LayerNorm -> feed-forward network -> dropout, then add the input back.

    Args:
        cfg: configuration dict. This class reads 'emb_dim', 'context_length',
            'n_heads', 'drop_rate' and 'qkv_bias', and passes the whole dict
            on to FeedForward.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']
        drop_rate = cfg['drop_rate']

        # Submodules are created in a fixed order (att, ff, norm1, norm2,
        # drop_shortcut). Reordering could change which random numbers each
        # parameter receives under a fixed seed, and the state_dict key order.
        self.att = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg['context_length'],
            num_heads=cfg['n_heads'],
            dropout=drop_rate,
            qkv_bias=cfg['qkv_bias'],
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(emb_dim)
        self.norm2 = LayerNorm(emb_dim)
        # One dropout module is shared by both shortcut branches.
        self.drop_shortcut = nn.Dropout(drop_rate)

    def _residual(self, x, norm, sublayer):
        """Compute ``dropout(sublayer(norm(x)))`` and add ``x`` back (shortcut)."""
        branch = self.drop_shortcut(sublayer(norm(x)))
        return branch + x

    def forward(self, x):
        """Apply the attention sub-layer, then the feed-forward sub-layer.

        The residual additions require each sub-layer to return a tensor
        shaped like its input, [batch_size, num_tokens, emb_size].
        """
        attended = self._residual(x, self.norm1, self.att)
        return self._residual(attended, self.norm2, self.ff)

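Let's test the transformer block on a random input batch: 2 sequences of 4 tokens each, where every token is a 768-dimensional (`emb_dim`) embedding.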

In [13]:
# Shape check: a random batch passed through one TransformerBlock should come
# out with exactly the same shape. The seed makes weights, input and dropout
# reproducible.
torch.manual_seed(123)
batch_size, num_tokens = 2, 4
# Take the embedding width from the config instead of hard-coding 768.
x = torch.rand(batch_size, num_tokens, GPT_CONFIG_124M["emb_dim"])
block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)
print('Input shape',x.shape)
print('Output shape',output.shape)
assert output.shape == x.shape, "TransformerBlock must preserve the input shape"

Input shape torch.Size([2, 4, 768])
Output shape torch.Size([2, 4, 768])


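* The transformer block preserves the shape of its input: the output is again `[batch_size, num_tokens, emb_dim]` = `[2, 4, 768]`. The values change, because each token vector has been updated by attention and the feed-forward network, but the dimensions do not. This shape preservation is what allows transformer blocks to be stacked (`n_layers` of them), with each block's output fed directly into the next.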