# Core Ideas of Nvidia's N-GPT

This notebook guide is designed for people who are already confident with modern transformers (e.g. Llama3). If you aren't there yet, check out my [Llama3 tutorial](https://colab.research.google.com/drive/10BKvPomnVVZw7UAT3wOaaPBdvfMEvOOY?usp=sharing) for an accelerated introduction designed for those who already understand basic math concepts like matrix multiplication, or [Andrej Karpathy's "Neural Networks: Zero to Hero" course](https://youtube.com/playlist?list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&si=8Z9BUgdFAnGBo71c) if you need to start from scratch.
The purpose of this guide is to provide intuition behind the architecture choices implemented in [Nvidia's N-GPT](https://arxiv.org/abs/2410.01131v1) without getting into the particulars (for those, read model.py).

Check out the YouTube video where I walk through the paper:
[![ERROR DISPLAYING IMAGE, CLICK HERE FOR VIDEO](https://img.youtube.com/vi/lZj8F6EspVU/0.jpg)](https://www.youtube.com/watch?v=lZj8F6EspVU)

**Note:** It's very easy to convince yourself that you understand something after watching a YouTube video about it, but chances are you don't actually understand it unless you can write out the math and code it from scratch on your own. I highly recommend doing so.

### Setup stuff

imports & hyperparameters & whatnot

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

In [12]:
@dataclass
class ModelConfig:
    dim: int = 8 # the model's embedding dimension
    device: str = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        # defaults to best available GPU/CPU
    max_seq_len: int = 5 # maximum number of tokens in the context
    theta: float = 10_000. # RoPE hyperparameter; 10_000 is the most common choice
    vocab_len: int = 2048 # options are 512, 1024, 2048
    num_layers: int = 4 # number of layers in the model
    num_heads: int = 2 # number of heads in the multi-head attention mechanism
    mlp_hidden_mult: float = 1.5 # how wide the hidden dimension of the MLP should be compared to dim


@dataclass
class TrainConfig:
    batch_size: int = 3 
    max_iters: int = 100 # total number of batches to run over the course of training
    # AdamW Hyperparameters https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
    beta1: float = 0.9
    beta2: float = 0.95
    epsilon: float = 1e-8
    # N-GPT disables weight-decay in the optimizer since it would move vectors off of the unit-hypersphere
    weight_decay: float = 0.0 
    # Maximum and minimum learning rates during annealing
    lr_init: float = 5e-3 # N-GPT does NOT need learning rate warmup because training is so stable
    lr_final: float = 1e-5

cfg = ModelConfig()
tcfg = TrainConfig()

### The broad ideas
<a id='2'></a>
The two key innovations of this architecture, and the results of implementing them, as stated by the authors, are as follows:

![picture](images/key_contributions.png)

Putting vectors onto the unit hypersphere in classification settings has been popular for quite a while in computer vision, for some reason especially in medical imaging. Whether transformers themselves act as in-context optimizers has been a hot topic of debate in in-context learning research, but from what I've read the idea does seem to have merit, at least in some contexts. What the authors did here was design the model in a manner that embraces this idea.



### 2a. Cosine Normalization

Traditional GPT models use either [Layer Norm](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html) or [RMS Norm](https://arxiv.org/abs/1910.07467), and usually apply it only to the residual stream vectors. Here we'll instead use cosine normalization, which both 1) does "more" normalization and 2) is also applied to the weight matrices along their embedding-length dimension.

LayerNorm and RMSNorm are relatively similar. Ignoring their learned parameters, both place vectors onto the hypersphere of radius $\sqrt{dim}$ centered at the origin. The main difference is that LayerNorm first subtracts each vector's mean, so its outputs additionally lie on the hyperplane orthogonal to the all-ones vector, whereas RMSNorm skips that centering step. Their learned gain (and, for LayerNorm, bias) then stretches and shifts the vectors off of that sphere. See my in-depth video explanation of what LayerNorm does to vectors [here](https://youtu.be/vlgLbQtL1RE).
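
A quick numerical check of what LayerNorm and RMSNorm do to vector lengths and means, with their learned parameters left out. This is a sketch: `F.layer_norm` is called with no weight or bias, and RMSNorm is hand-rolled (with an assumed `eps` of `1e-6`) rather than taken from any particular library class.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim = 8
x = torch.randn(5, dim) * 3 + 2  # arbitrary vectors with a nonzero mean and non-unit scale

# LayerNorm without learned affine parameters: subtract the mean, divide by the std
ln = F.layer_norm(x, normalized_shape=(dim,))

# RMSNorm without its learned gain: divide by the root-mean-square of the entries
rms = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

print("sqrt(dim):      ", math.sqrt(dim))
print("LayerNorm norms:", ln.norm(dim=-1))   # all ~sqrt(dim)
print("LayerNorm means:", ln.mean(dim=-1))   # all ~0, because LayerNorm centers
print("RMSNorm norms:  ", rms.norm(dim=-1))  # also all ~sqrt(dim)
print("RMSNorm means:  ", rms.mean(dim=-1))  # not forced to 0, since there is no centering
```

Both land on a radius-$\sqrt{dim}$ sphere; dividing by an extra $\sqrt{dim}$ is all that separates RMSNorm's output length from cosine normalization's radius of $1$.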

Cosine normalization, on the other hand, places all vectors onto the hypersphere of radius $1$ centered at the origin, aka the unit hypersphere, with no learned parameters to move them off of it. As the code below shows, it's relatively simple: just divide each vector by its own norm.

But how do we actually ensure that all of our weight matrices are cosine-normalized? The two implementations I'm aware of are to either normalize the weights in the forward pass right before using them, or to enforce the normalization after every gradient-descent step in the training loop. During training I'm not sure the choice matters much for efficiency, other than that iterating through the nn.Modules might be slower; for inference, though, it's definitely faster to not have to call the cosine_norm function on every single forward pass. In `training.ipynb` I've gone with normalizing in the training loop, right after the gradient update.
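
Here is a minimal sketch of the second option (re-normalizing right after each optimizer step). `TinyMLP` and `normalize_weights_` are hypothetical names made up for this illustration, the loss is a toy objective, and this is not necessarily identical to what `training.ipynb` does. The key detail is that `nn.Linear` stores its weight as `(out_features, in_features)`, so the embedding-length axis differs between the up- and down-projections.

```python
import torch
import torch.nn as nn


def cosine_norm(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # same helper as in this notebook: divide by the L2 norm along `dim`
    return x / torch.norm(x, p=2, dim=dim, keepdim=True).clamp(min=1e-6)


class TinyMLP(nn.Module):
    # hypothetical two-layer MLP, only here to have some weight matrices to normalize
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)    # weight shape: (hidden, dim)
        self.down = nn.Linear(hidden, dim, bias=False)  # weight shape: (dim, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(torch.relu(self.up(x)))


@torch.no_grad()
def normalize_weights_(model: TinyMLP) -> None:
    # the embedding-length axis is dim=1 for `up` and dim=0 for `down`
    model.up.weight.copy_(cosine_norm(model.up.weight, dim=1))
    model.down.weight.copy_(cosine_norm(model.down.weight, dim=0))


torch.manual_seed(0)
dim, hidden = 8, 12
model = TinyMLP(dim, hidden)
opt = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=0.0)
normalize_weights_(model)  # start training from weights that are already on the hypersphere

for _ in range(3):
    x = torch.randn(4, dim)
    loss = (model(x) - x).pow(2).mean()  # toy objective, just to produce gradients
    opt.zero_grad()
    loss.backward()
    opt.step()                 # this update moves the weights off the hypersphere...
    normalize_weights_(model)  # ...so project them straight back onto it

print(model.up.weight.norm(dim=1))    # one norm per embedding-length vector, all ~1
print(model.down.weight.norm(dim=0))
```

The forward-pass alternative would instead call `cosine_norm` on the weights inside `forward` (for example via `torch.nn.utils.parametrize.register_parametrization`), which keeps the constraint exact at the cost of extra work on every call.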

In [16]:
def cosine_norm(x: torch.Tensor, dim=-1) -> torch.Tensor:
    # calculate the magnitude of the vectors
    norm = torch.norm(x, p=2, dim=dim, keepdim=True).clamp(min=1e-6)
    # divide by the magnitude to place on the unit hypersphere
    return x / norm

In [18]:
# Example usage of cosine_norm on a residual state tensor and a weights matrix
residual_state = torch.randn(tcfg.batch_size, cfg.max_seq_len, cfg.dim, device=cfg.device)
weights_matrix = torch.randn(cfg.dim, int(cfg.dim * cfg.mlp_hidden_mult), device=cfg.device)

print("residual state norms:", torch.norm(residual_state, p=2, dim=-1, keepdim=True).clamp(min=1e-6))
print("weights matrix norms:", torch.norm(weights_matrix, p=2, dim=0, keepdim=True).clamp(min=1e-6))

# Normalize the residual state tensor
normalized_residual_state = cosine_norm(residual_state)

# Normalize the weights matrix along the embedding dimension (dim=0)
normalized_weights_matrix = cosine_norm(weights_matrix, dim=0)

print("Normalized residual state norms:", torch.norm(normalized_residual_state, p=2, dim=-1, keepdim=True).clamp(min=1e-6))
print("Normalized weights matrix norms:", torch.norm(normalized_weights_matrix, p=2, dim=0, keepdim=True).clamp(min=1e-6))

residual state norms: tensor([[[2.6883],
         [2.4465],
         [2.0404],
         [1.9038],
         [3.4173]],

        [[1.8770],
         [2.3548],
         [2.3282],
         [3.9562],
         [3.0683]],

        [[3.7325],
         [2.2250],
         [2.1887],
         [0.9713],
         [2.4724]]], device='mps:0')
weights matrix norms: tensor([[2.6406, 1.3802, 3.5841, 2.0801, 2.0012, 1.3385, 2.4916, 2.3445, 3.0002,
         2.3255, 2.5163, 2.0334]], device='mps:0')
Normalized residual state norms: tensor([[[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]]], device='mps:0')
Normalized weights matrix norms: tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000]], device='mps:0')


### 2b. Interpreting Matmuls
This is what matmuls look like with and without cosine normalization. Notice in the code example below how the un-normalized matmul is technically unbounded (in practice its entries cluster around $0$, with a spread that grows with the embedding dimension), while the normalized one is bounded in $[-1,1]$. The dot product between two vectors on the unit hypersphere is exactly their cosine similarity, where a value of $-1$ corresponds to vectors pointing in opposite directions, $0$ corresponds to orthogonal vectors, and $1$ is what you would get if they're the exact same vector.

Here's the traditional cosine similarity formula for non-normalized vectors, where $\cdot$ denotes the dot product, $||a||$ denotes the magnitude of $a$, and $\times$ is scalar multiplication:
$$ \frac{a \cdot b}{||a||\times||b||} $$

And since cosine-normalized vectors $a$ and $b$ have already been divided by their magnitudes (so $||a|| = ||b|| = 1$), the cosine similarity formula simplifies down to
$$a\cdot b$$

This property could open up interesting avenues for future interpretability research, since we can now understand matmuls as measuring the similarity between vectors in the input tensor and vectors in the weight tensor. Which vectors might the model find it important to compare the input against when deciding how to edit it?
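
To check the simplification numerically, the sketch below computes the cosine similarity of two random vectors three ways: with the textbook formula, as a plain dot product after normalizing, and with PyTorch's built-in `F.cosine_similarity`. The vector length of 16 is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(16)
b = torch.randn(16)

# textbook cosine similarity of the raw vectors: (a . b) / (||a|| * ||b||)
cos_textbook = torch.dot(a, b) / (a.norm() * b.norm())

# cosine-normalize first; after that a plain dot product is all that's needed
a_hat = a / a.norm()
b_hat = b / b.norm()
cos_dot = torch.dot(a_hat, b_hat)

# PyTorch's built-in, for comparison
cos_builtin = F.cosine_similarity(a, b, dim=0)

print(cos_textbook.item(), cos_dot.item(), cos_builtin.item())  # three (near-)identical values in [-1, 1]
```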

In [24]:
matmul = residual_state @ weights_matrix
normalized_matmul = normalized_residual_state @ normalized_weights_matrix

# comparing the value ranges of matmul & normalized_matmul
print("matmul value range:", matmul.min().item(), matmul.max().item())
print("normalized_matmul value range:", normalized_matmul.min().item(), normalized_matmul.max().item())

matmul value range: -6.9036760330200195 7.527524948120117
normalized_matmul value range: -0.7024921178817749 0.8642598986625671
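
With only a handful of samples the ranges above are noisy, so here is a sketch with a much larger sample. For standard-normal inputs, each un-normalized entry is a sum of `dim` products of independent $\mathcal{N}(0,1)$ values (each with variance $1$), so it should have mean near $0$ and standard deviation near $\sqrt{dim}$. `dim = 8` matches `cfg.dim`; the sample sizes are arbitrary choices for this illustration.

```python
import math
import torch


def cosine_norm(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # same helper as above: divide by the L2 norm along `dim`
    return x / torch.norm(x, p=2, dim=dim, keepdim=True).clamp(min=1e-6)


torch.manual_seed(0)
dim, n_vectors, n_columns = 8, 1000, 1000
x = torch.randn(n_vectors, dim)   # many standard-normal "residual" vectors
W = torch.randn(dim, n_columns)   # standard-normal weights, embedding dimension first

raw = x @ W
normed = cosine_norm(x) @ cosine_norm(W, dim=0)

print("raw mean / std:", raw.mean().item(), raw.std().item(), "vs sqrt(dim) =", math.sqrt(dim))
# normalized entries are cosine similarities, so they never leave [-1, 1]
print("normalized min / max:", normed.min().item(), normed.max().item())
```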


### 2c. Transformers as variable-metric optimizers
<a id='4'></a>

The idea here is to interpret each residual update (whether it comes from the attention mechanism or the multi-layer perceptron) as itself computing a gradient step toward the final token to be output. The prior interpretations of this sort that I'm aware of focused on understanding existing transformers through this lens, whereas this work actually shapes the residual connections to better perform this function.

The traditional residual connection update equation looks something like:
$$ h_{l+1} = h_l + h_l^A$$
$$ h_{l+1} = h_l + h_l^M$$
where $h_l$ is the hidden state at the $l$'th layer and $h_l^A = \text{Attention}(\text{RMSNorm}(h_l))$ and $h_l^M = \text{MLP}(\text{RMSNorm}(h_l))$ denote the output of the attention and MLP respectively.

The natural question to ask once you view residual connections this way is: how large a step are these models taking in the direction of the gradient? Prior GPT models had to implicitly fold this decision into the attention mechanism or MLP itself, so if that function could instead be separated out, it should leave more of the model's regular parameters free to compute the gradient itself. When we rephrase the goal in terms of variable-metric optimizers and also make our cosine-normalization adjustment, we get
$$ h_{l+1} = \text{Norm}(h_l + a_A * g_A) $$
$$ h_{l+1} = \text{Norm}(h_l + a_M * g_M) $$
where $\text{Norm}$ denotes cosine normalization, $h_l^A = \text{Norm}(\text{Attention}(h_l))$ and $h_l^M = \text{Norm}(\text{MLP}(h_l))$ are now the cosine-normalized outputs of the attention and MLP, $g_A = h_l^A - h_l$ and $g_M = h_l^M - h_l$ can be interpreted as the gradients from their respective modules, $*$ denotes entry-wise multiplication, and $a_A$ and $a_M$ are parameters determining the size of the gradient steps for their respective modules, analogous to $\eta$ in actual gradient descent. The authors call them "eigen" learning rate vectors, a name chosen for the meaning of the German word *eigen* ("own") rather than for any relation to eigenvalues and eigenvectors, which I find very confusing and think was a bad choice on their part.
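
A small sketch of the normalized update in isolation, assuming (as I read the paper) the form $\text{Norm}(h + a * (h^M - h))$. Random unit vectors stand in for real MLP outputs, and `normalized_residual_update` is a hypothetical helper name. It shows that the eigen learning rate interpolates between staying put ($a = 0$) and jumping all the way to the module's output ($a = 1$), with the result always back on the unit hypersphere.

```python
import torch


def cosine_norm(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # divide by the L2 norm along `dim` to land on the unit hypersphere
    return x / torch.norm(x, p=2, dim=dim, keepdim=True).clamp(min=1e-6)


def normalized_residual_update(h: torch.Tensor, h_M: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    # step from h toward h_M by a per-dimension amount a, then renormalize
    return cosine_norm(h + a * (h_M - h))


torch.manual_seed(0)
dim = 8
h = cosine_norm(torch.randn(dim))    # current hidden state, on the unit hypersphere
h_M = cosine_norm(torch.randn(dim))  # stand-in for a cosine-normalized MLP output

a_zero = torch.zeros(dim)
a_one = torch.ones(dim)
a_half = torch.full((dim,), 0.5)

print(torch.allclose(normalized_residual_update(h, h_M, a_zero), h, atol=1e-6))   # zero step: h is unchanged
print(torch.allclose(normalized_residual_update(h, h_M, a_one), h_M, atol=1e-6))  # full step: lands on h_M
print(normalized_residual_update(h, h_M, a_half).norm())                          # in between, still on the sphere
```

Because `a` is a vector, each embedding dimension gets its own step size, which is what makes this "variable-metric" rather than a single scalar learning rate.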

In [30]:
class MiniTransformerLayer(nn.Module):
    def __init__(self, dim, device):
        super().__init__()
        # let's pretend for a second that this is an entire multi-layer perceptron instead of a single linear layer
        self.MLP = nn.Linear(dim, dim, bias=False, device=device)

        # and now our eigen learning rate vector, initialized to 1 for all entries
        self.a_M = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, h_l: torch.Tensor) -> torch.Tensor:
        # first run the actual multi-layer perceptron
        h_M = self.MLP(h_l)
        # then do the residual update (cosine normalizations omitted here to keep the focus on a_M)
        h_lplus1 = h_l + self.a_M * (h_M - h_l)
        return h_lplus1

h_l = torch.randn(tcfg.batch_size, cfg.max_seq_len, cfg.dim, device=cfg.device)
mini_transformer_layer = MiniTransformerLayer(cfg.dim, cfg.device)
h_lplus1 = mini_transformer_layer(h_l)

### 2d. Scaling parameters
<a id='5'></a>

These edits are all well and good, but the astute observer might be concerned by what I showed earlier in section 2b: the distributions of the outputs of our operations are very different from what they would be without all these edits. I won't go too deep into the issues this raises (see pages 5 and 12 of the paper for more), but what we're about to do essentially amounts to controlling the learning rate of a couple of specific key parameters in the model without affecting all of the others.

Here's how the authors describe the methodology with the example of $a_A$ and $a_M$. It took me a few reads to understand what's going on since I had no code to look at, but hopefully my replication will help make the process clear for you.

![alt text](images/scaling_parameters.png)

In [36]:
class MiniTransformerLayer(nn.Module):
    def __init__(self, dim, device):
        super().__init__()
        # let's pretend for a second that this is an entire multi-layer perceptron instead of a single linear layer
        self.MLP = nn.Linear(dim, dim, bias=False, device=device)

        # define our scaling parameters
        self.a_M_scale = 1. / math.sqrt(dim)
        self.a_M_init = 1.

        # and now our eigen learning rate vector, which we initialize to a value of a_M_scale for all entries
        self.a_M = nn.Parameter(torch.ones(dim, device=device) * self.a_M_scale)

    def forward(self, h_l: torch.Tensor) -> torch.Tensor:
        # first run the actual multi-layer perceptron
        h_M = self.MLP(h_l)
        # then calculate our effective scaling parameter
        effective_a_M = self.a_M * (self.a_M_init / self.a_M_scale)
        # finally do the actual residual layer update
        h_lplus1 = h_l + effective_a_M * (h_M - h_l)
        return h_lplus1

h_l = torch.randn(tcfg.batch_size, cfg.max_seq_len, cfg.dim, device=cfg.device)
mini_transformer_layer = MiniTransformerLayer(cfg.dim, cfg.device)
h_lplus1 = mini_transformer_layer(h_l)

# aight, i guess now go read the actual code to see all the specifics
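
One way to see what this reparameterization buys you, assuming an Adam-style optimizer: Adam normalizes each gradient by its own running magnitude, so a stored parameter moves by roughly `lr` per step no matter how large its gradient is (on the very first step this holds exactly, up to `eps`). Storing `a` at `a_scale` and multiplying by `a_init / a_scale` in the forward pass therefore starts the effective value at `a_init` but multiplies its effective step size by `a_init / a_scale`. The sketch below compares one Adam step for a plain and a scaled parameterization on a made-up toy loss; `dim = 64` is chosen only so that the factor is a round `8`.

```python
import math
import torch

torch.manual_seed(0)
dim, lr = 64, 1e-3
a_init, a_scale = 1.0, 1.0 / math.sqrt(dim)
c = a_init / a_scale  # = sqrt(dim) = 8.0

# plain parameterization: the effective value IS the stored value
plain = torch.nn.Parameter(torch.full((dim,), a_init))
# scaled parameterization: store a_scale, multiply by c in the forward pass
scaled = torch.nn.Parameter(torch.full((dim,), a_scale))

target = torch.randn(dim)
opt = torch.optim.Adam([plain, scaled], lr=lr)

# the same toy loss, applied to the *effective* value of each parameterization
loss = ((plain - target) ** 2).sum() + ((scaled * c - target) ** 2).sum()
opt.zero_grad()
loss.backward()
opt.step()

# how far each effective value moved away from a_init after one step
plain_change = (plain.detach() - a_init).abs().mean()
scaled_change = (scaled.detach() * c - a_init).abs().mean()
print(plain_change.item(), scaled_change.item(), (scaled_change / plain_change).item())
```

Both parameters start at the same effective value, but the scaled one moves about `c` times further per step, so `a_scale` acts as a per-parameter learning-rate knob that leaves the global learning rate untouched.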