## Imports

In [56]:
import torch
from torch import nn
from torch.nn import functional as F
from dataclasses import dataclass
from typing import List, Tuple, Optional
from collections import OrderedDict

torch.set_printoptions(precision=3, sci_mode=False, linewidth=160)

## Topics
- RMS Normalization
- Rotary Positional Embeddings
- KV-Cache
- Multi-Query Attention
- Grouped-Query Attention (GQA)
- SwiGLU Activation Function

<img src="assets/LLaMA-Architecture.png" alt="Drawing" style="width: 800px;"/>

In [3]:
@dataclass
class LlamaConfig:
    """Hyper-parameters and runtime settings shared by the model classes in this notebook.

    Fields that no visible cell reads yet are marked as such below; their
    intended meaning is inferred from their names and should be confirmed
    once the attention / feed-forward blocks are implemented.
    """
    # -1 is a "not set" sentinel: LLaMA.__init__ asserts vocab_size != -1.
    vocab_size: int = -1
    # Embedding width; also the size of the RMSNorm scale vector.
    hidden_size: int = 4096
    # Number of LlamaBlock instances stacked by LLaMA.
    n_layers: int = 32
    n_attention_heads: int = 32 # number of attention heads for queries
    n_key_value_heads: Optional[int] = None # number of attention heads for keys and values
    # NOTE(review): multiple_of / ffn_dim_multiplier / intermediate_size are not read by
    # any code visible in this notebook excerpt -- presumably feed-forward sizing; confirm.
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    intermediate_size: int = 16384
    # Epsilon added to the mean square inside RMSNorm before the rsqrt.
    norm_eps: float = 1e-5
    # NOTE(review): not read by any visible code.
    dropout: float = 0.1
    
    # NOTE(review): max_batch_size is not read by any visible code.
    max_batch_size: int = 32
    # Number of positions for which rotary frequencies are precomputed.
    max_seq_len: int = 2048
    # Passed as the `device=` argument when building the positional table; None is the default.
    device: Optional[str] = None

In [4]:
class LlamaBlock(nn.Module):
    """Placeholder for a single LLaMA transformer block.

    Only the constructor exists so far: it initialises the ``nn.Module``
    machinery and creates no sub-modules or parameters.

    TODO: implement ``forward``. ``LLaMA.forward`` invokes every block as
    ``layer(token_embeddings, start_position, freq_complex)``.
    """

    def __init__(self, config: LlamaConfig) -> None:
        # `config` is accepted for interface parity with the other modules
        # but is not used yet.
        super().__init__()

In [5]:
class RMSNorm(nn.Module):
    """Constructor-only sketch of RMS normalisation (no ``forward`` is defined here).

    Attributes:
        config: the ``LlamaConfig`` the layer was built from.
        eps: ``config.norm_eps``.
        scale: learnable vector of ``config.hidden_size`` ones.
    """

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__()

        self.config = config
        self.eps = config.norm_eps
        # Assigning an nn.Parameter attribute registers it under the name "scale".
        self.scale = nn.Parameter(torch.ones(config.hidden_size))

### RoPE

#### Rotary Positional Embeddings: Combining Absolute and Relative

- Introduction
  - Discusses the importance of positional embeddings in Transformer models.
  
- Absolute Positional Embeddings
  - Explains how absolute positional embeddings work.
  - Highlights limitations like fixed sequence length and lack of relative context.
  
- Relative Positional Embeddings
  - Introduces the concept of relative positional embeddings.
  - Discusses the computational challenges and inefficiencies.
  
- Rotary Positional Embeddings (RoPE)
  - Combines the advantages of both absolute and relative embeddings.
  - Uses rotation to encode position, preserving relative distances.
  
- Matrix Formulation
  - Explains the mathematical formulation behind RoPE.
  
- Implementation
  - Shows how RoPE can be implemented efficiently in PyTorch.
  
- Experiments and Conclusion
  - Shares results of experiments showing RoPE's effectiveness and efficiency compared to other methods.

The video provides a comprehensive overview of Rotary Positional Embeddings, a new method that combines the strengths of both absolute and relative positional embeddings. It delves into the mathematical details and practical implementation, concluding with experimental results that validate its effectiveness.

In [60]:
# Toy RoPE settings used by the demo cells below.
# `theta` is the base of the geometric frequency progression.
theta = 10000.0
# Number of positions in the demo table.
max_seq_len = 10
# Embedding width of the demo vectors; consumed two components at a time below.
hidden_size = 14

$\theta_i = 10000^{-2(i-1) / \text{hidden\_size}}, \quad i \in \{1, 2, \dots, \text{hidden\_size}/2\}$

In [61]:
# theta_i = 10000 ^ (-2(i-1) / hidden_size) for i in [1, 2, ..., hidden_size/2]
# (torch.arange(0, hidden_size, 2) supplies the values 2(i-1) directly.)
thetas = (1 / theta) ** (torch.arange(0, hidden_size, 2) / hidden_size) # (hidden_size//2,)
thetas.shape

torch.Size([7])

In [62]:
thetas

tensor([    1.000,     0.268,     0.072,     0.019,     0.005,     0.001,     0.000])

In [63]:
# Position indices 0 .. max_seq_len-1.
m = torch.arange(max_seq_len)
m

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [64]:
# Outer product: entry [p, i] is the rotation angle p * thetas[i].
m_thetas = torch.outer(m, thetas).float() # (max_seq_len, hidden_size//2)
m_thetas.shape

torch.Size([10, 7])

In [65]:
m_thetas

tensor([[    0.000,     0.000,     0.000,     0.000,     0.000,     0.000,     0.000],
        [    1.000,     0.268,     0.072,     0.019,     0.005,     0.001,     0.000],
        [    2.000,     0.537,     0.144,     0.039,     0.010,     0.003,     0.001],
        [    3.000,     0.805,     0.216,     0.058,     0.016,     0.004,     0.001],
        [    4.000,     1.073,     0.288,     0.077,     0.021,     0.006,     0.001],
        [    5.000,     1.341,     0.360,     0.097,     0.026,     0.007,     0.002],
        [    6.000,     1.610,     0.432,     0.116,     0.031,     0.008,     0.002],
        [    7.000,     1.878,     0.504,     0.135,     0.036,     0.010,     0.003],
        [    8.000,     2.146,     0.576,     0.154,     0.041,     0.011,     0.003],
        [    9.000,     2.414,     0.648,     0.174,     0.047,     0.013,     0.003]])

In [66]:
# torch.polar(abs, angle) builds abs * (cos(angle) + 1j * sin(angle)); with abs == 1
# every entry is a unit-magnitude complex number whose angle is m_thetas[p, i].
freqs_complex = torch.polar(torch.ones_like(m_thetas), m_thetas)
freqs_complex.shape

torch.Size([10, 7])

In [67]:
freqs_complex

tensor([[ 1.000+0.000j,  1.000+0.000j,  1.000+0.000j,  1.000+0.000j,  1.000+0.000j,  1.000+0.000j,  1.000+0.000j],
        [ 0.540+0.841j,  0.964+0.265j,  0.997+0.072j,  1.000+0.019j,  1.000+0.005j,  1.000+0.001j,  1.000+0.000j],
        [-0.416+0.909j,  0.859+0.511j,  0.990+0.143j,  0.999+0.039j,  1.000+0.010j,  1.000+0.003j,  1.000+0.001j],
        [-0.990+0.141j,  0.693+0.721j,  0.977+0.214j,  0.998+0.058j,  1.000+0.016j,  1.000+0.004j,  1.000+0.001j],
        [-0.654-0.757j,  0.477+0.879j,  0.959+0.284j,  0.997+0.077j,  1.000+0.021j,  1.000+0.006j,  1.000+0.001j],
        [ 0.284-0.959j,  0.227+0.974j,  0.936+0.352j,  0.995+0.096j,  1.000+0.026j,  1.000+0.007j,  1.000+0.002j],
        [ 0.960-0.279j, -0.039+0.999j,  0.908+0.419j,  0.993+0.116j,  1.000+0.031j,  1.000+0.008j,  1.000+0.002j],
        [ 0.754+0.657j, -0.302+0.953j,  0.876+0.483j,  0.991+0.135j,  0.999+0.036j,  1.000+0.010j,  1.000+0.003j],
        [-0.146+0.989j, -0.544+0.839j,  0.839+0.544j,  0.988+0.154j,  0.999+0.04

### Apply Rotary Positional Embeddings

In [70]:
# Random toy input of shape (1, 1, hidden_size).
# NOTE(review): no random seed is set in the visible cells, so the values shown
# below will differ on a fresh run.
x = torch.randn(1, 1, hidden_size)
x, x.shape

(tensor([[[-0.829,  0.845,  0.287,  0.847,  1.093, -1.112, -1.014, -0.143, -0.983, -1.131,  1.792,  0.058,  0.219,  0.337]]]),
 torch.Size([1, 1, 14]))

In [73]:
# Group consecutive components into (real, imag) pairs:
# (1, 1, hidden_size) -> (1, 1, hidden_size//2, 2)
x_reshaped = x.float().view(*x.shape[:-1], -1, 2)
x_reshaped

tensor([[[[-0.829,  0.845],
          [ 0.287,  0.847],
          [ 1.093, -1.112],
          [-1.014, -0.143],
          [-0.983, -1.131],
          [ 1.792,  0.058],
          [ 0.219,  0.337]]]])

In [75]:
x_reshaped.shape

torch.Size([1, 1, 7, 2])

In [76]:
# Reinterpret each trailing pair as one complex number:
# (1, 1, hidden_size//2, 2) -> (1, 1, hidden_size//2)
x_complex = torch.view_as_complex(x_reshaped)
x_complex

tensor([[[-0.829+0.845j,  0.287+0.847j,  1.093-1.112j, -1.014-0.143j, -0.983-1.131j,  1.792+0.058j,  0.219+0.337j]]])

In [77]:
x_complex.shape

torch.Size([1, 1, 7])

In [80]:
# Select the table row for one position and add singleton axes so it can broadcast.
position = 1
freq_complex = freqs_complex[position:position+1][None, :, None, ...] # (1, hidden_size//2) -> (1, 1, 1, hidden_size//2)
freq_complex

tensor([[[[0.540+0.841j, 0.964+0.265j, 0.997+0.072j, 1.000+0.019j, 1.000+0.005j, 1.000+0.001j, 1.000+0.000j]]]])

In [82]:
# Complex multiplication by a unit-magnitude number rotates each pair by its angle.
# Here x_complex is (1, 1, 7) and freq_complex is (1, 1, 1, 7), so the product
# broadcasts to (1, 1, 1, 7).
x_rotated = x_complex * freq_complex
x_rotated.shape

torch.Size([1, 1, 1, 7])

In [83]:
x_rotated

tensor([[[[-1.158-0.241j,  0.052+0.893j,  1.170-1.031j, -1.011-0.162j, -0.977-1.136j,  1.792+0.060j,  0.218+0.337j]]]])

In [85]:
# Split each complex number back into a (real, imag) pair: (1, 1, 1, 7) -> (1, 1, 1, 7, 2)
x_real = torch.view_as_real(x_rotated)
x_real

tensor([[[[[-1.158, -0.241],
           [ 0.052,  0.893],
           [ 1.170, -1.031],
           [-1.011, -0.162],
           [-0.977, -1.136],
           [ 1.792,  0.060],
           [ 0.218,  0.337]]]]])

In [86]:
# Flatten the pairs back to the input layout and restore the input's tensor type:
# (1, 1, 1, 7, 2) -> (1, 1, 14)
res = x_real.view(x.shape).type_as(x)
res

tensor([[[-1.158, -0.241,  0.052,  0.893,  1.170, -1.031, -1.011, -0.162, -0.977, -1.136,  1.792,  0.060,  0.218,  0.337]]])

In [87]:
res.shape, x.shape

(torch.Size([1, 1, 14]), torch.Size([1, 1, 14]))

### RMS Normalization

Root Mean Square Layer Normalization (RMSNorm) emerges as a refinement over traditional normalization techniques employed in transformer architectures, aiming to address computational overheads while retaining, or even enhancing, the model's performance.  
Initially, transformers relied on Layer Normalization (LayerNorm), as used in Vaswani et al. 2017, but the advent of LLaMA brought RMSNorm to the forefront, offering a more streamlined normalization approach.  

The core idea behind RMSNorm is derived from its predecessor, Layer Normalization (LayerNorm), which was noted for its two pivotal properties: re-centering and re-scaling. Re-centering aids in rendering the model robust to shift noises in inputs and weights,  
while re-scaling preserves the output representations amidst random scaling of inputs and weights. However, it was recognized that the majority of the benefits stemmed from the re-scaling aspect, which led to the conceptualization of RMSNorm.  

RMSNorm, embodying simplicity, omits the re-centering (mean-centering) operation from LayerNorm, focusing solely on the re-scaling invariance. This is achieved by normalizing the summed inputs according to the root mean square (RMS) statistic, expressed as:  
$$ \overline{a}_i = \frac{a_i}{\texttt{RMS}}$$ where $$ \texttt{RMS} = \sqrt{ \frac{1}{n}\sum_{i=1}^n a_i^2} $$  
Here, $a_i$ denotes the activation of the ith neuron. By sidestepping the mean computation, RMSNorm not only simplifies the normalization process but also curtails the computational time, marking a significant stride towards efficiency.   

This nuanced approach dovetails with the transition from post-normalization to pre-normalization, as observed in LLaMA. Unlike the original Transformer that applied normalization post the attention layer, LLaMA, inspired by GPT-3,   
normalizes the input before feeding it to the self-attention and feed-forward layers. RMSNorm serves as the linchpin in this pre-normalization process, ensuring that the neural network's activations are duly regulated before entering the subsequent layers.  

The impact of RMSNorm is palpable, with empirical evidence indicating a reduction in running time of between 7% and 64% when juxtaposed with LayerNorm, all the while maintaining comparable performance levels. This demonstrates RMSNorm's capability to balance computational   
efficiency with performance efficacy, making it a valuable asset in the ongoing evolution of transformer architectures.  

<img src="assets/RMS-Normalization.png" alt="Drawing" style="width: 800px;"/>

In [127]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the input by the root mean square of its last dimension and
    multiplies by a learnable per-feature ``scale``. No mean is subtracted.
    """

    def __init__(self, config: LlamaConfig):
        """
        Root Mean Square Layer Normalization

        Args:
            config: LlamaConfig; ``norm_eps`` and ``hidden_size`` are read from it.
        """

        super().__init__()
        self.eps = config.norm_eps
        self.register_parameter('scale', nn.Parameter(torch.ones(config.hidden_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension.

        Args:
            x: tensor whose last dimension equals ``hidden_size``,
                e.g. (B, seq_len, hidden_size).

        Returns:
            Tensor with the same shape and tensor type as ``x``.
        """
        # Do the reduction in float32 so low-precision inputs do not lose
        # accuracy in the mean of squares; a no-op for float32 inputs.
        x_fp32 = x.float()
        # (B,seq_len,hidden_size) * (B,seq_len,1) -> (B,seq_len,hidden_size)
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        # (hidden_size) * (B,seq_len,hidden_size) -> (B,seq_len,hidden_size)
        # `type_as` (not the non-existent `as_type`) casts back to the input's type.
        return (self.scale * x_normed).type_as(x)

In [136]:
# Toy inputs for a step-by-step RMSNorm walk-through.
# Note: `x` is rebound here; it no longer refers to the RoPE demo tensor.
x = torch.randn(10, 1, hidden_size)
# Stand-in for the learnable scale vector, drawn uniformly from [0, 1).
weight = torch.rand(hidden_size)

x.shape, weight.shape

(torch.Size([10, 1, 14]), torch.Size([14]))

In [137]:
# Divide every vector by the RMS of its last dimension:
# (10, 1, 14) * (10, 1, 1) -> (10, 1, 14).
# NOTE(review): eps is hard-coded to 1e-6 here, while LlamaConfig.norm_eps defaults to 1e-5.
x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
x_normed

tensor([[[-1.715, -0.063,  1.662,  0.384, -0.274,  0.199, -1.481,  0.306,  0.391, -1.988, -0.609,  1.110,  0.009, -0.184]],

        [[ 0.133,  1.512,  0.952, -0.218,  0.354,  1.264,  0.758,  0.051, -1.049, -0.243, -0.041,  0.474,  1.051,  2.439]],

        [[ 0.035,  0.812,  0.134,  0.358,  0.978, -0.392, -1.400, -0.069, -1.031, -2.237,  1.066,  0.061, -1.204,  1.207]],

        [[ 0.608, -1.350, -2.355,  1.070, -0.426,  0.727, -0.118, -0.769,  1.163, -0.412, -0.699,  0.546,  0.581,  1.074]],

        [[ 0.221,  1.581, -0.584, -1.565, -0.122,  1.393, -0.101, -0.473,  0.566,  0.271,  1.273,  1.366,  0.762,  1.418]],

        [[ 0.628,  0.176,  0.488, -0.035, -0.174, -1.080, -1.616,  1.695, -1.271, -1.166, -0.094, -0.720,  0.202, -1.763]],

        [[ 1.114,  0.312, -0.263, -1.128,  1.112, -0.910, -0.882, -2.272, -0.046,  1.052, -0.366,  0.364, -0.174, -1.381]],

        [[-0.022, -1.279, -0.660,  0.648,  1.433, -0.345,  0.363, -0.753,  0.016,  1.520,  1.251, -0.635, -0.564, -2.010]],



In [138]:
# Per-feature scaling: (14,) * (10, 1, 14) broadcasts to (10, 1, 14).
# Note: `res` is rebound here; it previously held the RoPE demo result.
res = weight * x_normed
res

tensor([[[    -1.495,     -0.041,      1.045,      0.023,     -0.258,      0.182,     -0.246,      0.003,      0.041,     -0.721,     -0.335,      0.231,
               0.008,     -0.106]],

        [[     0.116,      0.980,      0.599,     -0.013,      0.334,      1.155,      0.126,      0.001,     -0.109,     -0.088,     -0.022,      0.099,
               0.940,      1.403]],

        [[     0.030,      0.526,      0.084,      0.021,      0.924,     -0.359,     -0.233,     -0.001,     -0.108,     -0.811,      0.586,      0.013,
              -1.077,      0.694]],

        [[     0.530,     -0.875,     -1.480,      0.063,     -0.402,      0.665,     -0.020,     -0.008,      0.121,     -0.149,     -0.384,      0.114,
               0.520,      0.618]],

        [[     0.193,      1.025,     -0.367,     -0.093,     -0.115,      1.274,     -0.017,     -0.005,      0.059,      0.098,      0.700,      0.284,
               0.682,      0.815]],

        [[     0.548,      0.114,      0.307,

In [78]:
class LLaMA(nn.Module):
    """LLaMA-style language model: embedding -> LlamaBlock stack -> RMSNorm -> LM head.

    ``forward`` processes exactly one token per sequence per call (it asserts
    ``seq_len == 1``) and is told the token's absolute position through
    ``start_position``.
    """

    def __init__(self, config: LlamaConfig) -> None:
        """Build the sub-modules and the rotary frequency table.

        Args:
            config: model configuration; ``vocab_size`` must have been set.

        Raises:
            AssertionError: if ``config.vocab_size`` is still the -1 sentinel,
                or if ``config.hidden_size`` is odd.
        """
        super().__init__()

        assert config.vocab_size != -1, 'vocab_size must be specified'

        self.device = config.device

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        # Kept as a plain attribute rather than a registered buffer so that
        # Module.to(dtype) / .half() never casts the complex-valued table.
        self.freq_complex = self.precompute_pos_frequencies(config)
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock(config)) for i in range(config.n_layers)])
        )
        # RMSNorm takes the whole config and reads hidden_size / norm_eps itself.
        self.rms_norm = RMSNorm(config)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def precompute_pos_frequencies(self, config: LlamaConfig, theta: float = 10000.0) -> torch.Tensor:
        """Precompute the complex rotation factors used by rotary positional embeddings.

        Args:
            config: supplies ``hidden_size``, ``max_seq_len`` and ``device``.
            theta: base of the geometric frequency progression.

        Returns:
            Complex tensor of shape (max_seq_len, hidden_size // 2) whose entry
            [p, i] has magnitude 1 and angle ``p * theta_i``.

        Raises:
            AssertionError: if ``hidden_size`` is odd.
        """
        hidden_size = config.hidden_size
        max_seq_len = config.max_seq_len
        device = config.device

        assert hidden_size % 2 == 0, 'hidden_size must be even: RoPe cannot be appied to odd-dimensional embeddings'

        # theta_i = theta ^ (-2(i-1) / hidden_size) for i in [1, 2, ..., hidden_size/2]
        # Created on `device` so it can be combined with `m` below.
        thetas = (1 / theta) ** (torch.arange(0, hidden_size, 2, device=device) / hidden_size) # (hidden_size//2,)
        m = torch.arange(max_seq_len, device=device, dtype=torch.float) # (max_seq_len,)
        # Outer product with the frequency *tensor* (not the scalar base).
        freqs = torch.outer(m, thetas) # (max_seq_len, hidden_size//2)
        freqs_complex = torch.polar(torch.ones_like(freqs), freqs) # (max_seq_len, hidden_size//2)
        return freqs_complex

    def apply_rotary_embeddings(self, x: torch.Tensor, position: int) -> torch.Tensor:
        """Rotate consecutive component pairs of ``x`` by the angles stored for ``position``.

        Args:
            x: tensor with an even-sized last dimension. The selected table row
                is reshaped to (1, 1, 1, D // 2) and broadcast against the
                paired-up input, so ``x.shape[-1] // 2`` must match the table width.
            position: absolute position; must be a valid row of ``self.freq_complex``.

        Returns:
            Tensor with the same shape and tensor type as ``x``.

        Raises:
            IndexError: if ``position`` is outside the precomputed table.
        """
        n_positions = self.freq_complex.shape[0]
        if not 0 <= position < n_positions:
            raise IndexError(f"position {position} is outside the precomputed range [0, {n_positions})")

        # (..., D) -> (..., D//2, 2) -> complex (..., D//2)
        x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # (1, D//2) -> (1, 1, 1, D//2)
        freq_complex = self.freq_complex[position:position + 1][None, :, None, ...]
        x_rotated = x_complex * freq_complex
        # Complex -> trailing (real, imag) pairs; this is the inverse of view_as_complex.
        x_real = torch.view_as_real(x_rotated)
        out = x_real.reshape(x.shape).type_as(x)
        if self.device is not None:
            out = out.to(self.device)
        return out

    def forward(self, input_ids: torch.Tensor, start_position: int, target: Optional[torch.Tensor] = None) -> dict:
        """Run one decoding step.

        Args:
            input_ids: token ids of shape (batch_size, 1).
            start_position: absolute position of the token being processed.
            target: optional target ids; when given, a cross-entropy loss is computed.

        Returns:
            ``{'logits': logits}``, plus a ``'loss'`` entry when ``target`` is given.

        Raises:
            AssertionError: if ``seq_len != 1`` or the requested positions fall
                outside the precomputed frequency table.
        """
        # input_ids: (batch_size, seq_len)
        bs, seq_len = input_ids.shape
        assert seq_len == 1, 'sequence length must be 1'
        assert 0 <= start_position and start_position + seq_len <= self.freq_complex.shape[0], \
            'start_position is outside the precomputed positional range'

        token_embeddings = self.embeddings(input_ids) # (batch_size, seq_len, hidden_size)
        freq_complex = self.freq_complex[start_position:start_position + seq_len] # (seq_len, hidden_size//2)

        # Iterated manually (instead of calling the nn.Sequential) because each
        # block receives extra positional arguments.
        for layer in self.llama_blocks:
            token_embeddings = layer(token_embeddings, start_position, freq_complex)
        logits = self.head(self.rms_norm(token_embeddings))

        if target is None:
            return {'logits': logits}
        else:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
            return {'logits': logits, 'loss': loss}