# Llama Code

## RMSNorm

Suppose $\mathbf{X}\in\mathbb{R}^{\cdots\times d}$ has any number of leading dimensions and a last dimension of size $d$. RMSNorm is applied along the last dimension: for each vector $\mathbf{a}=(a_{1},\dots,a_{d})$ taken along that dimension,

$$
\begin{aligned}
\text{RMS}(\mathbf{a}) &= \sqrt{\frac{1}{d}\sum_{i=1}^{d}a_{i}^{2} + \epsilon}\\
\bar{a_{i}} &= \frac{a_{i}}{\text{RMS}(\mathbf{a})}w_{i}
\end{aligned}
$$

Here $\mathbf{w}\in\mathbb{R}^{d}$ is the learnable weight, and $\epsilon$ is a small constant added for numerical stability.

In [1]:
import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    """
    Root-mean-square normalization over the last dimension with a learnable scale.

    Each vector along the last dimension is divided by the square root of
    (mean of its squared entries + eps) and then multiplied element-wise
    by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Set up the layer's stability constant and scale parameter.

        Args:
            dim (int): Size of the last dimension of the inputs; also the length of ``weight``.
            eps (float, optional): Constant added to the mean square before the
                reciprocal square root is taken. Default is 1e-6.

        Attributes:
            eps (float): The stability constant given above.
            weight (nn.Parameter): Learnable per-feature scale, initialized to ones.

        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Scale ``x`` by the reciprocal of its root mean square along the last dimension.

        Args:
            x (torch.Tensor): Tensor to normalize.

        Returns:
            torch.Tensor: ``x`` divided by sqrt(mean(x^2) + eps); the learnable
            weight is not applied here.

        """
        # keepdim=True keeps a trailing size-1 axis so the factor broadcasts over x.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """
        Normalize ``x`` and apply the learnable scale.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized input multiplied by ``weight``.

        """
        # Normalize in float32, then cast back to the input's dtype before scaling.
        normalized = self._norm(x.float())
        normalized = normalized.type_as(x)
        return normalized * self.weight

In [8]:
# Build a random example tensor of shape (2, 3, 4) for the step-by-step cells below.
dim = 4  # equals the size of x's last dimension; not referenced in the randn call
x = torch.randn(2, 3, 4)
x

tensor([[[ 1.0185, -0.0323, -0.1116,  0.2731],
         [ 0.4527, -0.8262,  0.5210,  1.4562],
         [-0.2279, -1.4570,  0.3255,  0.8584]],

        [[-1.2045,  0.0891,  0.3982, -0.0681],
         [-1.2457,  0.6341,  1.8910,  0.4150],
         [ 1.5039, -0.9606,  1.0872, -0.3793]]])

In [9]:
# Step 1: square every element of x.
x.pow(2)

tensor([[[1.0374e+00, 1.0458e-03, 1.2463e-02, 7.4576e-02],
         [2.0498e-01, 6.8258e-01, 2.7140e-01, 2.1205e+00],
         [5.1918e-02, 2.1227e+00, 1.0594e-01, 7.3680e-01]],

        [[1.4509e+00, 7.9472e-03, 1.5857e-01, 4.6335e-03],
         [1.5517e+00, 4.0211e-01, 3.5758e+00, 1.7223e-01],
         [2.2617e+00, 9.2283e-01, 1.1820e+00, 1.4383e-01]]])

In [10]:
# Step 2: average the squares over the last dimension; keepdim=True leaves a trailing size-1 axis.
x.pow(2).mean(-1, keepdim=True)

tensor([[[0.2814],
         [0.8199],
         [0.7543]],

        [[0.4055],
         [1.4255],
         [1.1276]]])

In [12]:
# Step 3: reciprocal square root of the mean square (no eps is added in this demo line).
torch.rsqrt(x.pow(2).mean(-1, keepdim=True))

tensor([[[1.8852],
         [1.1044],
         [1.1514]],

        [[1.5704],
         [0.8376],
         [0.9417]]])

## Rotary Embedding

In [15]:
from typing import Tuple

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors by a precomputed complex frequency tensor.

    The last dimension of each input is split into consecutive pairs, and each
    pair is viewed as one complex number. Both complex tensors are multiplied by
    'freqs_cis' (after it is passed through 'reshape_for_broadcast' together with
    the complex query), converted back to real pairs, and flattened from
    dimension 3 onward. Results are cast back to the dtype of the matching input.

    Note: the last dimension of 'xq' and 'xk' must be even for the pairwise
    reshape to succeed. Presumably the inputs are 4-D, e.g.
    (batch, seq_len, n_heads, head_dim), since the result is flattened starting
    at dimension 3 -- confirm against the caller.

    Args:
        xq (torch.Tensor): Query tensor to rotate.
        xk (torch.Tensor): Key tensor to rotate.
        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The rotated query and key tensors, as real tensors.
    """
    # Work in float32 and group the last dimension into (real, imag) pairs.
    q_pairs = xq.float().reshape(*xq.shape[:-1], -1, 2)
    k_pairs = xk.float().reshape(*xk.shape[:-1], -1, 2)
    q_complex = torch.view_as_complex(q_pairs)
    k_complex = torch.view_as_complex(k_pairs)

    # The same reshaped frequency tensor is applied to both queries and keys.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)

    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each head along dimension 2 ``n_rep`` times.

    Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep).

    Args:
        x (torch.Tensor): 4-D tensor unpacked as (bs, slen, n_kv_heads, head_dim).
        n_rep (int): Number of copies of each head.

    Returns:
        torch.Tensor: Tensor of shape (bs, slen, n_kv_heads * n_rep, head_dim).
        When ``n_rep`` is 1 the input tensor itself is returned.
    """
    batch, seq_len, kv_heads, dim_per_head = x.shape
    if n_rep != 1:
        # Insert a size-1 axis after the head axis, broadcast it to n_rep,
        # then merge it into the head axis so copies of a head are adjacent.
        widened = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, dim_per_head)
        return widened.reshape(batch, seq_len, kv_heads * n_rep, dim_per_head)
    return x