GPT ARCHITECTURE PART 2: LAYER NORMALIZATION

Explanation with a simple example

In [2]:
import torch
import torch.nn as nn

torch.manual_seed(123)
batch_example = torch.randn(2, 5)
layer = nn.Sequential(nn.Linear(5, 6), nn.ReLU())
out = layer(batch_example)
print(out)

tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],
        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],
       grad_fn=<ReluBackward0>)


The neural network layer we have coded consists of a Linear layer followed by a non-linear activation function (ReLU), which is a standard activation function in neural networks.

If you are unfamiliar with ReLU, it simply thresholds negative inputs to 0, ensuring that the layer's output does not contain any negative values.

(Note that we will use another, more sophisticated activation function in GPT, which we will introduce later.)
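As a quick illustration of the thresholding behavior described above, here is a minimal sketch that applies ReLU to a handful of made-up values (the input numbers are an assumption chosen for illustration, not taken from the example batch):

```python
import torch
import torch.nn as nn

relu = nn.ReLU()
values = torch.tensor([-2.0, -0.5, 0.0, 1.5])  # illustrative inputs, not from the text
clipped = relu(values)

# Negative entries become 0; non-negative entries pass through unchanged
print(clipped)
```

This is also why the layer output above contains several exact zeros.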

Before we apply layer normalization to these outputs, let's examine the mean and variance:

In [3]:
mean = out.mean(dim=-1, keepdim=True)
var = out.var(dim=-1, keepdim=True)
print("Mean:", mean)
print("Variance:", var)

Mean: tensor([[0.1324],
        [0.2170]], grad_fn=<MeanBackward1>)
Variance: tensor([[0.0231],
        [0.0398]], grad_fn=<VarBackward0>)


The first row in the mean tensor above contains the mean value for the first input row, and the second row contains the mean value for the second input row. The same applies to the variance tensor.
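The role of dim=-1 and keepdim=True in the calculation above can be seen in a small sketch. The 2x3 tensor here is made up for illustration; the point is only the resulting shapes:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 6.0, 8.0]])  # made-up batch: 2 rows, 3 features

mean_keep = x.mean(dim=-1, keepdim=True)  # one mean per row, shape (2, 1)
mean_flat = x.mean(dim=-1)                # same values, but shape (2,)
print(mean_keep.shape, mean_flat.shape)

# The (2, 1) shape broadcasts across each row, so every row is
# centered with its own mean
centered = x - mean_keep
print(centered)
```

Keeping the reduced dimension is what lets the per-row statistics be subtracted from (and divided into) the original tensor directly.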

In [4]:
out_norm = (out - mean) / torch.sqrt(var)
mean = out_norm.mean(dim=-1, keepdim=True)
var = out_norm.var(dim=-1, keepdim=True)
print("Normalized layer outputs:\n", out_norm)
print("Mean: \n", mean)
print("Variance: \n", var)

Normalized layer outputs:
 tensor([[ 0.6159,  1.4126, -0.8719,  0.5872, -0.8719, -0.8719],
        [-0.0189,  0.1121, -1.0876,  1.5173,  0.5647, -1.0876]],
       grad_fn=<DivBackward0>)
Mean: 
 tensor([[9.9341e-09],
        [0.0000e+00]], grad_fn=<MeanBackward1>)
Variance: 
 tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


Note that the value 9.9341e-09 in the output tensor is the scientific notation for $9.9341 \times 10^{-9}$, which is 0.0000000099341 in decimal form. This value is very close to 0, but it is not exactly 0 due to small numerical errors that can accumulate because of the finite precision with which computers represent numbers.
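The same kind of rounding error shows up in plain Python, independent of PyTorch. The following sketch uses a standard textbook case (not taken from the tensors above): adding 0.1 ten times does not give exactly 1.

```python
import math

total = sum([0.1] * 10)  # mathematically, this is exactly 1

print(total == 1.0)              # False: small rounding errors accumulate
print(math.isclose(total, 1.0))  # True: equal within a small tolerance
```

For this reason, results such as a mean of 9.9341e-09 are best read as "zero up to floating-point precision".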

To improve readability, we can also turn off the scientific notation when printing tensor values by setting sci_mode to False:

In [5]:
torch.set_printoptions(sci_mode=False)
print("Mean:", mean)
print("Variance:", var)

Mean: tensor([[    0.0000],
        [    0.0000]], grad_fn=<MeanBackward1>)
Variance: tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


Let's now encapsulate this process in a PyTorch module that we can use in the GPT model later:   

In [6]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

This specific implementation of layer normalization operates on the last dimension of the input tensor $\mathbf{x}$, which represents the embedding dimension (emb_dim).

The variable eps is a small constant (epsilon) added to the variance to prevent division by zero during normalization.
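To see why eps matters, consider a sketch with a made-up input row in which all values are identical, so that its variance is exactly 0. The tensor and the constant 3.0 are assumptions for illustration:

```python
import torch

x = torch.full((1, 4), 3.0)  # made-up row in which every value is identical
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)  # exactly 0 for a constant row

no_eps = (x - mean) / torch.sqrt(var)           # 0 / 0 yields nan
with_eps = (x - mean) / torch.sqrt(var + 1e-5)  # 0 / sqrt(1e-5) yields 0

print(no_eps)
print(with_eps)
```

Without eps, the nan values would propagate through every later layer; with eps, the degenerate row is simply mapped to zeros.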

The scale and shift are two trainable parameters (of the same dimension as the input) that the LLM automatically adjusts during training if it is determined that doing so would improve the model's performance on its training task.

This allows the model to learn appropriate scaling and shifting that best suit the data it is processing.
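As a sanity check, the module can be compared against PyTorch's built-in nn.LayerNorm. The sketch below repeats the class so that it runs on its own; the seed, the input shape, and the embedding size of 8 are arbitrary assumptions. With scale initialized to ones and shift to zeros, both layers should agree up to floating-point tolerance:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):  # same module as above, repeated to keep this standalone
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

torch.manual_seed(0)                 # arbitrary seed for this sketch
x = torch.randn(2, 3, 8)             # made-up (batch, tokens, emb_dim) input
custom = LayerNorm(emb_dim=8)
builtin = nn.LayerNorm(8, eps=1e-5)  # built-in layer; also uses the biased variance

print(torch.allclose(custom(x), builtin(x), atol=1e-5))
```

Once training updates scale and shift, the outputs are no longer forced to have zero mean and unit variance, which is the flexibility described above.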

A small note on biased variance

In our variance calculation, we set unbiased=False.

For those curious about what this means: the variance formula then divides by the number of inputs n.

This approach does not apply Bessel's correction, which typically uses n-1 instead of n in the denominator to adjust for bias in sample variance estimation.

This decision results in a so-called biased estimate of the variance.

For large language models (LLMs), where the embedding dimension n is large, the difference between using n and n-1 is practically negligible.

We chose this approach to ensure compatibility with the GPT-2 model's normalization layers and because it reflects the default behavior of TensorFlow, which was used to implement the original GPT-2 model.
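The effect of dividing by n instead of n-1 can be checked with Python's standard library. In this sketch the data values are made up; only the sample sizes matter, since the ratio of the two estimates is always n / (n - 1):

```python
import statistics

small = [0.5, 1.5, 2.0, 4.0, 2.0]           # n = 5, made-up values
large = [float(i % 7) for i in range(768)]  # n = 768, made-up values

ratios = {}
for data in (small, large):
    n = len(data)
    biased = statistics.pvariance(data)   # divides by n
    unbiased = statistics.variance(data)  # divides by n - 1 (Bessel's correction)
    ratios[n] = unbiased / biased         # equals n / (n - 1)
    print(n, ratios[n])
```

For n = 5 the two estimates differ by 25%, whereas for n = 768 (the embedding dimension used later in this notebook) they differ by roughly 0.13%.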

Let's now try the LayerNorm module in practice and apply it to the batch input:

In [7]:
ln = LayerNorm(emb_dim=5)
out_ln = ln(batch_example)
mean = out_ln.mean(dim=-1, keepdim=True)
var = out_ln.var(dim=-1, unbiased=False, keepdim=True)
print("mean:", mean)
print("var:", var)

mean: tensor([[    -0.0000],
        [     0.0000]], grad_fn=<MeanBackward1>)
var: tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)


In [8]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tanh-based approximation of the GELU activation
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) *
            (x + 0.044715 * torch.pow(x, 3))
        ))
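The forward method above is the tanh-based approximation of GELU. As a rough check of how close it is to the exact definition, x * Phi(x) with Phi the standard normal CDF, the following sketch compares the two using only Python's math module. The function names and the evaluation grid are assumptions made for this illustration:

```python
import math

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi written in terms of the error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation, same formula as in the GELU module above
    return 0.5 * x * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    ))

xs = [i / 10 for i in range(-50, 51)]  # grid from -5.0 to 5.0 in steps of 0.1
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max abs difference on [-5, 5]: {max_err:.6f}")
```

On this grid the two versions differ only by a small fraction of the activation's magnitude, so the approximation is a reasonable stand-in for the exact form.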

In [14]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Expand to 4x the embedding size, apply GELU, then project back
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

In [16]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768,          # Embedding dimension
    "n_heads": 12,           # Number of attention heads
    "n_layers": 12,          # Number of layers
    "drop_rate": 0.1,        # Dropout rate
    "qkv_bias": False,       # Query-Key-Value bias
}

ffn = FeedForward(GPT_CONFIG_124M)
x = torch.rand(2, 3, 768)
out = ffn(x)
print(out.shape)  # Expected output shape: (2, 3, 768)

torch.Size([2, 3, 768])
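Although the output shape matches the input shape, the FeedForward module expands to 4 * emb_dim internally, which makes it fairly parameter-heavy. A back-of-the-envelope count, assuming the two nn.Linear layers above with their default bias terms, can be sketched as:

```python
emb_dim = 768             # "emb_dim" from GPT_CONFIG_124M
hidden_dim = 4 * emb_dim  # width of the expanded layer: 3072

# Each nn.Linear(in, out) holds in * out weights plus out biases
first = emb_dim * hidden_dim + hidden_dim  # 768 -> 3072
second = hidden_dim * emb_dim + emb_dim    # 3072 -> 768
total = first + second

print(first, second, total)
```

This gives about 4.7 million parameters for a single FeedForward module; GELU itself adds none.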
