# Implement RMS Normalization

### Problem Statement

**Root Mean Square Layer Normalization (RMSNorm)** is a simplification of Layer Normalization that is widely used in modern LLMs such as LLaMA, Mistral, and Gemma. Your task is to implement RMSNorm from scratch using PyTorch.

---

### Background

Unlike LayerNorm, RMSNorm **does not center the inputs** (no mean subtraction). It only rescales each vector by its root mean square, which skips the mean and variance computations and makes it cheaper, while model quality is reported to stay comparable.

#### Mathematical Formula

For an input vector $x \in \mathbb{R}^d$:

$$
\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma
$$

Where:
- $\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}$
- $\gamma$ is a learnable scale parameter (initialized to ones)
- $\epsilon$ is a small constant for numerical stability
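To make the formula concrete, here is a small worked example in plain Python for $x = (3, 4)$ with $\gamma$ set to ones. The helper name `rms_norm` is only an illustration and is not part of the exercise; it follows the definition above, with $\epsilon$ added inside the square root.

```python
import math


def rms_norm(x, gamma=None, eps=1e-8):
    """Apply the RMSNorm formula to one vector given as a list of floats."""
    d = len(x)
    if gamma is None:
        gamma = [1.0] * d  # gamma initialized to ones
    # RMS(x) = sqrt(mean(x_i^2) + eps)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    return [g * v / rms for v, g in zip(x, gamma)]


x = [3.0, 4.0]
y = rms_norm(x)  # mean of squares = (9 + 16) / 2 = 12.5, RMS ≈ 3.5355
print(y)         # approximately [0.8485, 1.1314]
```

The mean of the squared outputs is $(0.72 + 1.28) / 2 = 1$, so the normalized vector has unit RMS (up to $\epsilon$) before $\gamma$ rescales it.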

#### Why RMSNorm?

1. **Faster**: No mean subtraction or variance computation is required
2. **Simpler**: Fewer operations than LayerNorm
3. **Effective**: Empirically reported to perform on par with LayerNorm in LLMs

---

### Requirements

1. **Define an `RMSNorm` class** that:
   - Inherits from `nn.Module`
   - Has a learnable scale parameter `gamma` of shape `(dim,)`
   - Applies RMS normalization along the last dimension

2. **Handle arbitrary batch dimensions**:
   - Input shape: `(..., dim)` where `...` represents any number of batch dimensions
   - Output shape: same as input
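One way to satisfy the shape requirement, sketched here under the assumption that the reduction uses `dim=-1` with `keepdim=True`: the mean of squares keeps a trailing dimension of size 1, so it broadcasts against the input regardless of how many batch dimensions come first. The tensor sizes are arbitrary.

```python
import torch

x = torch.randn(2, 3, 4)  # two leading batch dimensions, dim = 4
gamma = torch.ones(4)     # scale parameter of shape (dim,)

# Reduce over the last dimension only; keepdim=True leaves shape (2, 3, 1)
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)

# (2, 3, 4) / (2, 3, 1) * (4,) broadcasts back to (2, 3, 4)
y = x / torch.sqrt(mean_sq + 1e-8) * gamma

print(mean_sq.shape)  # torch.Size([2, 3, 1])
print(y.shape)        # torch.Size([2, 3, 4])
```

Without `keepdim=True` the reduced tensor would have shape `(2, 3)`, which does not broadcast against `(2, 3, 4)`.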

---


In [None]:
import torch
import torch.nn as nn


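class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization over the last dimension."""

    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # added inside the square root for numerical stability
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable scale, shape (dim,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: shape (..., dim)
        # RMS over the last dimension; keepdim=True keeps shape (..., 1) for broadcasting
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.gamma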


In [None]:
x = torch.randn(3, 5)  # e.g., (batch_size=3, features=5)
rmsnorm = RMSNorm(dim=5)
out = rmsnorm(x)
print(out.shape)  # prints torch.Size([3, 5])
assert out.shape == (3, 5), "Output shape mismatch"
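
A shape check alone does not confirm the normalization itself. Two properties follow from the formula and can be tested directly: with $\gamma$ set to ones every output vector has RMS close to 1, and rescaling the input leaves the output essentially unchanged. The sketch below uses a hypothetical functional helper, `rms_norm_ref`, which is not part of the original notebook; the same assertions should hold for the `RMSNorm` module above at initialization.

```python
import torch


def rms_norm_ref(x, gamma, eps=1e-8):
    # Hypothetical reference helper: same computation as the formula above
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma


torch.manual_seed(0)
x = torch.randn(2, 4, 8) * 5.0 + 3.0  # arbitrary scale and offset, two batch dims
gamma = torch.ones(8)
out = rms_norm_ref(x, gamma)

# 1) Each output vector has RMS close to 1 when gamma is all ones
out_rms = out.pow(2).mean(dim=-1).sqrt()  # shape (2, 4)
assert torch.allclose(out_rms, torch.ones_like(out_rms), atol=1e-4)

# 2) Multiplying the input by a constant leaves the output essentially unchanged
assert torch.allclose(rms_norm_ref(10.0 * x, gamma), out, atol=1e-4)
```

The scale invariance is only approximate because $\epsilon$ does not scale with the input, but for inputs of ordinary magnitude the difference is far below the tolerance used here.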