<a href="https://colab.research.google.com/github/kapilsh/gpt-oss-scratch/blob/main/gpt_oss_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
  """Root-mean-square normalization over the last dimension of the input.

  Every vector along the last axis is multiplied by
  ``1 / sqrt(mean(x ** 2) + eps)`` and then scaled element-wise by a learnable
  ``weight`` vector that is initialized to ones.

  Args:
    embedding_dimension: Length of the learnable ``weight`` vector; the last
      axis of the input must broadcast against it.
    eps: Constant added to the mean square before the reciprocal square root,
      which keeps an all-zero vector from producing a division by zero.
  """

  def __init__(self, embedding_dimension: int, eps: float=1e-5):
    super().__init__()
    self.eps = eps
    self.embedding_dimension = embedding_dimension
    self.weight = nn.Parameter(torch.ones(embedding_dimension))

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Normalize ``x`` along its last axis and apply the learnable scale.

    Args:
      x: Input tensor whose last axis is normalized.

    Returns:
      A tensor with the same shape as ``x``. Its dtype follows the usual type
      promotion between ``x`` and ``self.weight``.
    """
    input_dtype = x.dtype
    # Squaring in float16 overflows to inf once |x| exceeds ~256, which makes
    # rsqrt return 0 and silently zeroes the output; bfloat16 loses precision
    # in the mean. Compute the statistics in float32 for those dtypes only, so
    # every other dtype follows exactly the same path as before.
    if input_dtype in (torch.float16, torch.bfloat16):
      x = x.float()
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    normalized = x * torch.rsqrt(mean_square + self.eps)
    # Cast back before scaling so the result dtype is still decided by the
    # promotion of the caller's dtype with the weight's dtype.
    return normalized.to(input_dtype) * self.weight

In [3]:
# Seed the global RNG so the random example batch is the same on every run.
torch.manual_seed(123)

# Small random tensor; RMSNorm normalizes along its last axis (size 4).
example_batch = torch.randn(2, 3, 4)

# Build the from-scratch layer and PyTorch's built-in layer with the same
# feature width and eps.
rms_norm = RMSNorm(example_batch.size(-1))
rmsnorm_pytorch = nn.RMSNorm(example_batch.size(-1), eps=1e-5)

# Sanity check: the from-scratch layer must agree with the built-in one.
assert torch.allclose(
    rms_norm(example_batch),
    rmsnorm_pytorch(example_batch),
)

# How it works

- Take the mean of the squared values of `x` over the last dimension = `means`
- Divide every element of `x` by `sqrt(means + eps)` (equivalently, multiply by `rsqrt(means + eps)`) = `normalized_x`
- Multiply `normalized_x` by `weights`, a trainable parameter with one entry per feature

Let's see an example

In [6]:
# Step-by-step walkthrough of what `rms_norm` does to `example_batch`,
# printing every intermediate tensor.
print("Actual batch")
print(example_batch)
print("-" * 25)
print("Squared batch")
powered_x = example_batch.pow(2)

print(powered_x)

print("-" * 25)
print("Means of powered x")
means = powered_x.mean(dim=-1, keepdim=True)

print(means)
print("-" * 25)
print("Reciprocal Root mean squared means")
# Read eps from the layer itself so the walkthrough cannot drift from the module.
rsqrt_means = torch.rsqrt(means + rms_norm.eps)

print(rsqrt_means)
print("-" * 25)
print("Normalized batch")
normalized_x = example_batch * rsqrt_means

print(normalized_x)
print("-" * 25)
print("Weights")
# Show the layer's actual weight values (detached so it prints as a plain
# tensor) rather than a separately constructed tensor of ones.
weights = rms_norm.weight.detach()

print(weights)
print("-" * 25)
print("Final batch")
# Scale by the same weights that were just printed.
final_batch = normalized_x * weights

print(final_batch)

# The manual computation must reproduce the module's own output.
assert torch.allclose(final_batch, rms_norm(example_batch))

Actual batch
tensor([[[ 0.3374, -0.1778, -0.3035, -0.5880],
         [ 0.3486,  0.6603, -0.2196, -0.3792],
         [-0.1606, -0.4015,  0.6957, -1.8061]],

        [[ 1.8960, -0.1750,  1.3689, -1.6033],
         [-0.7849, -1.4096, -0.4076,  0.7953],
         [ 0.9985,  0.2212,  1.8319, -0.3378]]])
-------------------------
Squared batch
tensor([[[0.1138, 0.0316, 0.0921, 0.3458],
         [0.1215, 0.4361, 0.0482, 0.1438],
         [0.0258, 0.1612, 0.4840, 3.2618]],

        [[3.5947, 0.0306, 1.8739, 2.5705],
         [0.6160, 1.9869, 0.1662, 0.6325],
         [0.9971, 0.0490, 3.3558, 0.1141]]])
-------------------------
Means of powered x
tensor([[[0.1458],
         [0.1874],
         [0.9832]],

        [[2.0174],
         [0.8504],
         [1.1290]]])
-------------------------
Reciprocal Root mean squared means
tensor([[[2.6186],
         [2.3100],
         [1.0085]],

        [[0.7040],
         [1.0844],
         [0.9411]]])
-------------------------
Normalized batch
tensor([[[ 0.8