In [1]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def precompute_theta_pos_frequencies(head_dim:int, seq_len:int, device:str, theta:float=10000.0):
    """Precompute the complex rotary-embedding factors for every position.

    Args:
        head_dim: Size of one attention head. Must be even, because features
            are rotated in pairs.
        seq_len: Number of positions (0 .. seq_len - 1) to build factors for.
        device: Device on which the returned table is placed.
        theta: Base of the geometric progression of rotation frequencies.

    Returns:
        Complex tensor of shape (seq_len, head_dim / 2). Entry [m, i] is the
        unit complex number with angle m * theta ** (-2i / head_dim).

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "Dimension (divided with head) must be divisible by 2"

    # Exponents 2i / head_dim for i = 0 .. head_dim/2 - 1 -> (head_dim / 2)
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    # Per-pair rotation frequency: 1 / theta^(2i / head_dim)
    inv_freq = 1.0 / torch.pow(theta, exponents).to(device)

    # Position indices 0 .. seq_len - 1 -> (seq_len)
    positions = torch.arange(seq_len, device=device)

    # (seq_len) outer (head_dim / 2) -> (seq_len, head_dim / 2)
    angles = torch.outer(positions, inv_freq).float()

    # torch.polar(abs, angle) = abs * cos(angle) + abs * sin(angle) * j; abs is 1 here
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

In [3]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex:torch.Tensor, device:str):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor whose last dimension is even; the shape comments below
            assume (B, seq_len, num_heads, head_dim).
        freqs_complex: Complex rotation factors, expected as
            (seq_len, head_dim / 2) so they broadcast over batch and heads.
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``, placed on ``device``.
    """
    # Group the last axis into (real, imag) pairs and reinterpret as complex:
    # (B, seq_len, num_heads, head_dim) -> (B, seq_len, num_heads, head_dim / 2)
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)

    # Insert broadcast axes for batch and heads:
    # (seq_len, head_dim / 2) -> (1, seq_len, 1, head_dim / 2)
    rotation = freqs_complex[None, :, None]

    # Complex multiplication by a unit-magnitude factor rotates each pair;
    # view_as_real splits the result back into a trailing axis of size 2.
    rotated_pairs = torch.view_as_real(as_complex * rotation)

    # (B, seq_len, num_heads, head_dim / 2, 2) -> (B, seq_len, num_heads, head_dim)
    restored = rotated_pairs.reshape(*x.shape)
    return restored.type_as(x).to(device)

In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable gain.

    Args:
        dim: Size of the feature (last) axis; also the length of ``weight``.
        epsilon: Small constant added to the mean square before the reciprocal
            square root, to avoid division by zero.
    """

    def __init__(self, dim: int, epsilon: float = 1e-8):
        super().__init__()
        self.dim = dim
        self.eps = epsilon
        # Per-feature gain, initialised to 1 so the layer starts as plain RMS scaling
        self.weight = nn.Parameter(torch.ones(dim))

    def _reciprocal_rms(self, x: torch.Tensor):
        '''
        rms(x) = sqrt(mean(x^2))
        reciprocal_rms(x) = x / rms(x)

        The mean is taken over the last (feature) axis.
        '''

        # Reduce over axis -1, the feature axis. self.dim is the feature *size*,
        # not an axis index, so it must not be passed as `dim=` here.
        # (B, seq_len, dim) * (B, seq_len, 1) -> (B, seq_len, dim)
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` along its last axis and scale by ``weight``.

        Args:
            x: Tensor whose last axis has size ``dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Compute the statistics in float32, then cast back to the input dtype
        # before applying the gain; without the upcast, type_as would be a no-op.
        # (dim) * (B, seq_len, dim) -> (B, seq_len, dim)
        return self.weight * self._reciprocal_rms(x.float()).type_as(x)

In [5]:
@dataclass
class ModelArgs:
    """Hyperparameters and runtime limits for the transformer model."""

    dim: int = 4096  # model (embedding) width
    n_layer: int = 32
    n_heads: int = 32  # number of query heads
    # Number of key/value heads; None means "use n_heads" (resolved in SelfAttention)
    n_kv_heads: Optional[int] = None
    # -1 is a placeholder to be overwritten by the caller.
    # NOTE(review): the name looks like a typo for `vocab_size`; it is kept
    # as-is because renaming the field would break existing callers.
    vocal_size: int = -1
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-6

    # needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Evaluated once, when the class is defined
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Previously written as `multiple_of = int = 256`, a chained assignment
    # that bound class attributes `multiple_of` and `int` to 256 and never
    # created a dataclass field. Declared last so the positional order of the
    # fields above is unchanged.
    # NOTE(review): presumably the feed-forward hidden size is rounded up to a
    # multiple of this value - confirm against the feed-forward block.
    multiple_of: int = 256

In [6]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    '''
    Repeat every key/value head n_rep times along the head axis.

    x: (B, seq_len, n_kv_heads, head_dim) -> (B, seq_len, n_kv_heads * n_rep, head_dim)

    Copies of one head are adjacent in the output (h0, h0, ..., h1, h1, ...).
    When n_rep == 1 the input tensor itself is returned.
    '''

    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x

    # Add a repeat axis right after the head axis and broadcast it to n_rep ...
    tiled = x.unsqueeze(3).expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
    # ... then fold (n_kv_heads, n_rep) into a single head axis.
    return tiled.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)

In [None]:
class SelfAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings and a KV cache.

    Keys and values are produced for ``n_kv_heads`` heads and repeated
    ``n_rep`` times so they line up with the ``n_heads_q`` query heads.

    Notes:
        * No attention mask is applied in ``forward``; every query position
          attends to every cached position up to ``start_pos + seq_len``.
        * ``cache_k`` / ``cache_v`` are plain tensor attributes (not registered
          buffers), so they stay on ``args.device`` even if the module is moved.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        # calculate n_heads of Grouped Attention with KV cache
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        # How many query heads share one key/value head
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = args.dim // self.n_heads_q

        self.wq = nn.Linear(args.dim, self.n_heads_q * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        # Input width is what forward() feeds in (n_heads_q * head_dim); this
        # equals args.dim only when dim is divisible by n_heads.
        self.wo = nn.Linear(self.n_heads_q * self.head_dim, args.dim, bias=False)

        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape, device=args.device)
        self.cache_v = torch.zeros(cache_shape, device=args.device)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Attend the tokens in ``x`` to all cached positions, including themselves.

        Args:
            x: Input of shape (B, seq_len, dim).
            start_pos: Index in the cache where the first token of ``x`` is written.
            freqs_complex: Rotary factors for the positions covered by ``x``,
                passed straight to ``apply_rotary_embeddings``.

        Returns:
            Tensor of shape (B, seq_len, dim).
        """
        batch_size, seq_len, _ = x.size() # (B, seq_len, dim)

        xq = self.wq(x) # (B, seq_len, dim) -> (B, seq_len, n_heads_q * head_dim)
        xk = self.wk(x) # (B, seq_len, dim) -> (B, seq_len, n_heads_kv * head_dim)
        xv = self.wv(x) # (B, seq_len, dim) -> (B, seq_len, n_heads_kv * head_dim)

        # Split the projected features into heads
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim) # -> (B, seq_len, n_heads_q, head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) # -> (B, seq_len, n_heads_kv, head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) # -> (B, seq_len, n_heads_kv, head_dim)

        # position encoding for query, key (values are not rotated)
        xq = apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, x.device)

        # Write the new keys/values into their slots in the cache
        self.cache_k[:batch_size, start_pos:start_pos+seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos+seq_len] = xv

        # Read everything cached so far: (B, seq_len_kv, n_heads_kv, head_dim)
        keys = self.cache_k[:batch_size, :start_pos+seq_len]
        values = self.cache_v[:batch_size, :start_pos+seq_len]

        # repeat the heads of K and V to match the number of heads of Q
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (B, seq_len, n_heads_q, head_dim) -> (B, n_heads_q, seq_len, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product scores. head_dim is a Python int, so the scale is
        # computed with ** 0.5 (torch.sqrt only accepts tensors).
        # (B, n_heads_q, seq_len, head_dim) * (B, n_heads_q, head_dim, seq_len_kv) -> (B, n_heads_q, seq_len, seq_len_kv)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / (self.head_dim ** 0.5)
        # Softmax in float32, then cast back to the query dtype
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, n_heads_q, seq_len, seq_len_kv) * (B, n_heads_q, seq_len_kv, head_dim) -> (B, n_heads_q, seq_len, head_dim)
        out = torch.matmul(scores, values)

        # transpose() returns a non-contiguous view; contiguous() is required before view()
        # (B, n_heads_q, seq_len, head_dim) -> (B, seq_len, n_heads_q, head_dim) -> (B, seq_len, n_heads_q * head_dim)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_heads_q * self.head_dim)
        # (B, seq_len, n_heads_q * head_dim) -> (B, seq_len, dim)
        out = self.wo(out)
        return out