In [1]:
from torch import nn
import torch
import torch.nn.functional as F

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by n_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(-2, -1)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
    
    
# test
torch.manual_seed(3333)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

mha = MultiHeadAttention(d_in=C, d_out=C, context_length=T, dropout=0.1, num_heads=4)
mha(x).shape
        


torch.Size([4, 8, 32])

In [3]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
        
# test
torch.manual_seed(3333)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

gelu = GELU()
gelu(x).shape

torch.Size([4, 8, 32])

In [4]:
class FeedForward(nn.Module):
    """ 
    Applies a feed-forward network to each token in the sequence independently.
    """
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            # nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        
    def forward(self, x):
        return self.net(x)
    
# test 
torch.manual_seed(3333)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

ff = FeedForward(embed_dim=C, dropout=0.1)
ff(x).shape

torch.Size([4, 8, 32])

In [5]:
class LayerNorm(nn.Module):
    """
    Every hidden unit is normalized based on the mean and std of the hidden units in the same layer.
    
    In contrast to BatchNorm, where the normalization is done based on input samples in the same batch.
    """
    
    def __init__(self, embed_dim, eps=1e-5):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.shift = nn.Parameter(torch.zeros(embed_dim))
        self.eps = eps
        
    def forward(self, x):
        # compute mean and std accross the last dimension (embed_dim) 
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        norm = (x - mean) / (std + self.eps) # normalizing the input
        return norm * self.scale + self.shift
    
# test
torch.manual_seed(3333)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

ln = LayerNorm(32)
out = ln(x)

# the output should result in each individual token embedding having mean 0 and std 1
out.mean(-1), out.std(-1)

(tensor([[ 3.7253e-09, -7.4506e-09,  1.8626e-09, -4.4703e-08, -1.1176e-08,
           7.4506e-09,  0.0000e+00,  1.1176e-08],
         [ 7.4506e-09,  2.9802e-08,  1.8626e-08, -7.4506e-09, -1.1176e-08,
           1.1176e-08,  3.7253e-09, -1.6764e-08],
         [ 7.4506e-09,  1.4901e-08,  2.9802e-08, -1.4901e-08, -2.9802e-08,
          -9.3132e-09,  2.2352e-08, -2.0489e-08],
         [ 2.2352e-08, -2.2352e-08, -1.8626e-08, -3.3528e-08, -7.4506e-09,
           1.1176e-08, -7.4506e-09, -7.4506e-09]], grad_fn=<MeanBackward1>),
 tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
        grad_fn=<StdBackward0>))

In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Shortcut connection for attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        # Shortcut connection for feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut  # Add the original input back

        return x
    
# test
torch.manual_seed(3333)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

cfg = {
    "emb_dim": 32,
    "context_length": 8,
    "n_heads": 2,
    "drop_rate": 0.1,
    "qkv_bias": False
}

tb = TransformerBlock(cfg)
out = tb(x)
out.shape

TypeError: FeedForward.__init__() missing 1 required positional argument: 'dropout'

In [None]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["dropout_rate"])

        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        B, T = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(T, device=in_idx.device))
        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits
    
# test
torch.manual_seed(3333)
x = torch.tensor([[1,2,3,4,5,6,7,8]]) # B=1, T=8

cfg = {
    "emb_dim": 32,
    "seq_len": 8,
    "num_heads": 2,
    "dropout_rate": 0.1,
    "qkv_bias": False,
    "context_length": 8,
    "vocab_size": 100,
    "n_layers": 2,
}

model = GPTModel(cfg)
out = model(x)
out.shape # for every of the eight tokens there is a probability distribution over the vocab_size (here: 100) if applying softmax to the logits -> sample from this distribution to get the next token