<a href="https://colab.research.google.com/github/kapilsh/gpt-oss-scratch/blob/main/gpt_oss_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT OSS Architecture

![GPT OSS Architecture](./resources/AI%20Knowledge%20Bank.jpg)

Let's build out these components one by one. PyTorch already provides implementations of some of them. In those cases we write our own version only to see how the module works, and switch to the PyTorch version at the end.

## RMS Norm

Paper: [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467) (Zhang & Sennrich, 2019)

![RMS Norm](./resources/rms_norm.png)

PyTorch already provides RMSNorm as `torch.nn.RMSNorm`, but we can write our own sample implementation and check it against the PyTorch one.
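
For reference, here is the computation written out as an equation. Assume an embedding vector $\mathbf{x}$ with $n$ features, a learned per-feature gain $\mathbf{g}$ (the `weight` parameter in the code below) and a small $\epsilon$ for numerical stability. This mirrors what `forward` computes, with $\epsilon$ added inside the square root as `torch.rsqrt(means + self.eps)` does:

$$
\mathrm{RMSNorm}(\mathbf{x})_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}\; g_i
$$

Unlike LayerNorm, there is no mean subtraction and no bias term. Each vector is only rescaled by its root mean square.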

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, embedding_dimension: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.embedding_dimension = embedding_dimension
        # Learnable per-feature gain, initialized to ones
        self.weight = nn.Parameter(torch.ones(embedding_dimension))

    def forward(self, x: torch.Tensor):
        # Mean of squares over the embedding (last) dimension; keepdim=True so it broadcasts
        means = x.pow(2).mean(dim=-1, keepdim=True)
        # Divide by the RMS (eps for numerical stability), then apply the gain
        return (x * torch.rsqrt(means + self.eps)) * self.weight
```

```python
x = torch.tensor(
    [
        [0.1200, -0.5600, 1.3400, 0.2200, -1.0500, 0.8700, -0.4400, 0.0900],
        [-0.9800, 0.4500, -0.1100, 0.6600, 0.7300, -0.3500, 0.2800, -0.6200],
        [0.5300, -0.7700, 0.0800, -0.3400, 0.1900, 1.0200, -0.5900, 0.4100],
        [1.2100, -0.1400, 0.3800, 0.7700, -0.9100, 0.6400, -0.2700, 0.1500],
    ]
)

rms_norm = RMSNorm(embedding_dimension=x.shape[-1])
rms_norm.weight.data = torch.tensor(
    [0.8, 1.2, 0.9, 1.1, 0.7, 1.3, 1.0, 0.95]
)  # Setting custom weights for demonstration
rms_norm_out = rms_norm(x)
print("RMSNorm Output:\n", rms_norm_out)
```

```
RMSNorm Output:
 tensor([[ 0.1320, -0.9238,  1.6579,  0.3327, -1.0104,  1.5548, -0.6049,  0.1175],
        [-1.3424,  0.9246, -0.1695,  1.2431,  0.8749, -0.7790,  0.4794, -1.0085],
        [ 0.7454, -1.6244,  0.1266, -0.6575,  0.2338,  2.3311, -1.0372,  0.6847],
        [ 1.4523, -0.2520,  0.5131,  1.2707, -0.9557,  1.2482, -0.4051,  0.2138]],
       grad_fn=<MulBackward0>)
```
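
As a sanity check on the arithmetic, here is a minimal sketch that recomputes the first row by hand in plain Python, with no PyTorch. It uses the same input row, custom weights and `eps` as above. The result should agree with the first row of the printed tensor to four decimal places. Small differences can come from PyTorch computing in float32 while Python uses float64.

```python
import math

# First row of the demo input `x` and the custom weights set above
row = [0.12, -0.56, 1.34, 0.22, -1.05, 0.87, -0.44, 0.09]
weight = [0.8, 1.2, 0.9, 1.1, 0.7, 1.3, 1.0, 0.95]
eps = 1e-5

# Mean of squares over the embedding dimension
mean_sq = sum(v * v for v in row) / len(row)
# Reciprocal of the root mean square (eps added inside the sqrt)
inv_rms = 1.0 / math.sqrt(mean_sq + eps)

# Normalize each feature, then apply its per-feature gain
out = [v * inv_rms * g for v, g in zip(row, weight)]
print([round(v, 4) for v in out])
```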



### Visualization

Here is a visualization of how RMSNorm works in practice, using an example with batch_size = 4 and embedding size = 8:

![RMS Norm Viz](./resources/rms_norm_viz.png)
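
To make the figure concrete, here is a minimal sketch of the same computation broken into steps for a (4, 8) input, printing the intermediate shapes. It assumes random data and a unit gain rather than the values in the figure, so only the shapes and the "each row ends up with unit RMS" property carry over.

```python
import torch

torch.manual_seed(0)
batch_size, embedding_dim = 4, 8
x = torch.randn(batch_size, embedding_dim)
weight = torch.ones(embedding_dim)  # learned gain, here left at its initial value of ones
eps = 1e-5

# Step 1: mean of squares per row; keepdim=True leaves shape (4, 1)
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
# Step 2: reciprocal RMS per row, still shape (4, 1)
inv_rms = torch.rsqrt(mean_sq + eps)
# Step 3: broadcast (4, 1) against (4, 8) so every feature in a row gets the same scale
normalized = x * inv_rms
# Step 4: broadcast the (8,) gain across all rows
out = normalized * weight

print(mean_sq.shape, inv_rms.shape, out.shape)
# With a unit gain, each output row has a mean of squares of (almost exactly) 1
print(out.pow(2).mean(dim=-1))
```

The `keepdim=True` in step 1 is what lets the per-row scale broadcast back over the embedding dimension without any reshaping.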

### Test

```python
torch.manual_seed(123)

example_batch = torch.randn(2, 3, 4)

rms_norm = RMSNorm(embedding_dimension=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

# Both modules start with weight = ones, so their outputs should match
assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
```