# Implement RMS Normalization

### Problem Statement

**Root Mean Square Layer Normalization (RMSNorm)** is a simplification of Layer Normalization that is widely used in modern LLMs such as LLaMA, Mistral, and Gemma. Your task is to implement RMSNorm from scratch using PyTorch.

---

### Background

Unlike LayerNorm, RMSNorm **does not center the inputs** (no mean subtraction). It only normalizes by the root mean square, making it computationally cheaper while maintaining similar performance.

#### Mathematical Formula

For an input vector $x \in \mathbb{R}^d$:

$$
\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma
$$

Where:
- $\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}$
- $\gamma$ is a learnable scale parameter (initialized to ones)
- $\epsilon$ is a small constant for numerical stability
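
As a quick sanity check of the formula, take $d = 2$, $x = (3, 4)$, $\gamma = (1, 1)$, and ignore $\epsilon$ (these values are chosen purely for illustration):

$$
\text{RMS}(x) = \sqrt{\tfrac{1}{2}\left(3^2 + 4^2\right)} = \sqrt{12.5} \approx 3.536,
\qquad
\text{RMSNorm}(x) \approx \left(\tfrac{3}{3.536},\ \tfrac{4}{3.536}\right) \approx (0.849,\ 1.131)
$$

The result has a root mean square of 1, since $\tfrac{1}{2}\left(0.849^2 + 1.131^2\right) \approx 1$.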

#### Why RMSNorm?

1. **Faster**: No mean subtraction or variance computation; the mean of squares is the only statistic needed
2. **Simpler**: Fewer operations than LayerNorm
3. **Effective**: Empirically performs comparably to LayerNorm in LLMs
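
To make the difference from LayerNorm concrete, the following sketch contrasts the two on a small input with a non-zero mean. It uses `torch.nn.functional.layer_norm` without affine parameters and an inline RMS normalization with $\gamma = 1$; the input values and the `eps` value are illustrative assumptions, not part of the exercise.

```python
import torch
import torch.nn.functional as F

eps = 1e-8  # illustrative value
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # one vector with mean 2.5

# LayerNorm without affine parameters: subtract the mean, divide by the std.
ln_out = F.layer_norm(x, normalized_shape=(4,), eps=eps)

# RMSNorm with gamma = 1: divide by the root mean square, no centering.
rms_out = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

print("LayerNorm output mean:", ln_out.mean(dim=-1))   # close to 0
print("RMSNorm output mean:  ", rms_out.mean(dim=-1))  # stays positive
print("RMSNorm output RMS:   ", rms_out.pow(2).mean(dim=-1).sqrt())  # close to 1

# On an input that is already zero-mean, the two normalizations coincide.
x_centered = x - x.mean(dim=-1, keepdim=True)
rms_centered = x_centered / torch.sqrt(
    x_centered.pow(2).mean(dim=-1, keepdim=True) + eps
)
ln_centered = F.layer_norm(x_centered, normalized_shape=(4,), eps=eps)
print("Match on centered input:", torch.allclose(rms_centered, ln_centered, atol=1e-5))
```

LayerNorm removes the mean, so its output averages to zero, while RMSNorm only rescales the vector and keeps its direction.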

---

### Requirements

1. **Define an `RMSNorm` class** that:
   - Inherits from `nn.Module`
   - Has a learnable scale parameter `gamma` of shape `(dim,)`
   - Applies RMS normalization along the last dimension

2. **Handle arbitrary batch dimensions**:
   - Input shape: `(..., dim)` where `...` represents any number of batch dimensions
   - Output shape: same as input

3. **Test your implementation**:
   - Verify output shape matches input shape
   - Check expected normalization properties: with `gamma` at its initial value of ones, each output vector should have a root mean square close to 1
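
For the shape requirement, the key detail is how the reduction broadcasts. The sketch below uses arbitrary example shapes and a plain tensor `gamma` as a stand-in for the learnable parameter. It shows that reducing over the last dimension with `keepdim=True` yields a `(..., 1)` tensor, which broadcasts against both the `(..., dim)` input and the `(dim,)` scale.

```python
import torch

dim = 5
gamma = torch.ones(dim)  # stand-in for the learnable scale, shape (dim,)

shapes_seen = []
for shape in [(dim,), (3, dim), (2, 4, dim), (2, 3, 4, dim)]:
    x = torch.randn(*shape)
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)   # shape (..., 1)
    out = x / torch.sqrt(mean_sq + 1e-8) * gamma    # shape (..., dim)
    shapes_seen.append((tuple(x.shape), tuple(mean_sq.shape), tuple(out.shape)))
    print(tuple(x.shape), "->", tuple(mean_sq.shape), "->", tuple(out.shape))
```

Without `keepdim=True` the reduced tensor would have shape `(...)`, and the division would either fail or broadcast along the wrong axis.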

---

### Constraints

- ✅ Use only PyTorch operations
- ✅ The scale parameter must be learnable (`nn.Parameter`)
- ✅ Must work with any input shape `(..., dim)`
- ❌ Do NOT subtract the mean (that would be LayerNorm)

---

<details>
  <summary>💡 Hint 1: Computing RMS</summary>
  Use <code>x.pow(2).mean(dim=-1, keepdim=True)</code> to compute the mean of squared values along the last dimension.
</details>

<details>
  <summary>💡 Hint 2: Numerical Stability</summary>
  Add epsilon <strong>inside</strong> the square root: <code>torch.sqrt(mean_sq + eps)</code>
</details>

<details>
  <summary>💡 Hint 3: Learnable Parameter</summary>
  Initialize gamma with <code>nn.Parameter(torch.ones(dim))</code> so it starts as identity scaling.
</details>
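
The role of epsilon in Hint 2 shows up on an all-zero input, where the RMS is exactly 0. The following minimal sketch (the input and the `1e-8` value are illustrative assumptions) compares the division with and without epsilon inside the square root.

```python
import torch

x = torch.zeros(1, 4)  # an all-zero vector has RMS exactly 0
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)

without_eps = x / torch.sqrt(mean_sq)        # 0 / 0 produces NaN
with_eps = x / torch.sqrt(mean_sq + 1e-8)    # 0 / 1e-4 produces 0

print(without_eps)  # tensor([[nan, nan, nan, nan]])
print(with_eps)     # tensor([[0., 0., 0., 0.]])
```

A NaN in the forward pass would propagate through the rest of the network, so the epsilon term matters even though it is negligible for typical inputs.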

---


In [1]:
import torch
import torch.nn as nn


In [4]:
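class RMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # added inside the square root for numerical stability
        self.gamma = nn.Parameter(torch.ones(dim))  # learnable scale, starts as identity

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the last dimension; keepdim=True keeps shape (..., 1)
        # so the division broadcasts against x of shape (..., dim).
        norm = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / norm) * self.gamma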


In [5]:
# Test your implementation
x = torch.randn(3, 5)  # (batch_size=3, features=5)
rmsnorm = RMSNorm(dim=5)
out = rmsnorm(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

assert out.shape == (3, 5), f"Output shape mismatch: expected (3, 5), got {out.shape}"
print("✓ Shape test passed!")

# Test with different batch dimensions
x_3d = torch.randn(2, 4, 5)  # (batch, seq_len, dim)
out_3d = rmsnorm(x_3d)
assert out_3d.shape == (2, 4, 5), (
    f"3D shape mismatch: expected (2, 4, 5), got {out_3d.shape}"
)
print("✓ 3D shape test passed!")

# Verify learnable parameters exist
params = list(rmsnorm.parameters())
assert len(params) == 1, f"Expected 1 learnable parameter (gamma), got {len(params)}"
assert params[0].shape == (5,), (
    f"Gamma shape mismatch: expected (5,), got {params[0].shape}"
)
print("✓ Parameter test passed!")

print("\n🎉 All tests passed!")


Input shape: torch.Size([3, 5])
Output shape: torch.Size([3, 5])
✓ Shape test passed!
✓ 3D shape test passed!
✓ Parameter test passed!

🎉 All tests passed!
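
The tests above cover shapes and parameters but not the normalization itself. As a hedged sketch of what such checks could look like, the snippet below verifies three properties that follow from the formula: unit RMS at initialization, invariance to positive rescaling of the input (up to the effect of $\epsilon$), and a gradient reaching `gamma`. The class is repeated so the snippet runs on its own; the seed, shapes, and tolerances are arbitrary choices.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    # Same computation as the solution cell, repeated for self-containment.
    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / norm) * self.gamma


torch.manual_seed(0)  # arbitrary seed for reproducibility
rmsnorm = RMSNorm(dim=5)
x = torch.randn(2, 4, 5)
out = rmsnorm(x)

# 1. With gamma = 1, every normalized vector has RMS close to 1.
out_rms = out.pow(2).mean(dim=-1).sqrt()
assert torch.allclose(out_rms, torch.ones_like(out_rms), atol=1e-4)

# 2. Rescaling the input by a positive constant leaves the output unchanged
#    (up to the small effect of eps).
assert torch.allclose(rmsnorm(10.0 * x), out, atol=1e-4)

# 3. Gradients flow back to gamma.
out.sum().backward()
assert rmsnorm.gamma.grad is not None
assert rmsnorm.gamma.grad.shape == (5,)

print("Normalization property checks passed")
```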
