In [1]:
from transformers import PretrainedConfig
from typing import List

In [6]:
import math
import struct
import inspect
import time
from typing import Any, Optional, Tuple, List
import numpy as np
from torch import nn 
from transformers import PreTrainedModel 
import torch


In [7]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Each vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``
    and then multiplied element-wise by ``weight``.

    Args:
        dim: Length of the learnable ``weight`` vector (initialised to ones).
            It must broadcast against the last dimension of the input.
        eps: Constant added to the mean square before the reciprocal square
            root, to keep the division finite for all-zero inputs.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` along its last dimension and apply the gain.

        The statistic is computed in float32 even for reduced-precision
        inputs: squaring float16 values above ~255 overflows to ``inf``,
        which would make the reciprocal root zero and wipe out the output.
        The normalised tensor is cast back to ``x``'s dtype before the gain
        is applied.
        """
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x_fp32 * inv_rms).type_as(x)


In [8]:
def precompute_pos_cis(dim: int, end: int = int(32*1024), theta: float = 1e6):
    """Build a table of unit-magnitude complex rotations, one row per position.

    Entry ``[p, k]`` equals ``exp(1j * p * theta ** (-2k / dim))`` for
    ``p`` in ``[0, end)`` and ``k`` in ``[0, dim // 2)``.

    Args:
        dim: Size of the dimension being rotated; the table has ``dim // 2``
            columns (one per consecutive pair).
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)``.
    """
    # One inverse frequency per pair of channels: theta^(-0/dim), theta^(-2/dim), ...
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

    # Positions are created as float32 so the outer product never mixes
    # integer and floating operands.
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs).float()
    # polar(magnitude=1, angle) -> cos(angle) + i*sin(angle)
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)
    return pos_cis


In [None]:
def apply_rotray_emb(xq, xk, pos_cis):
    """Rotate query and key tensors by precomputed complex position factors.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are treated
    as (real, imag) parts of complex numbers, multiplied by ``pos_cis``, and
    unpacked back into real pairs.

    NOTE(review): the function name looks like a misspelling of "rotary"; it
    is kept unchanged so existing callers continue to work.

    Args:
        xq: Query tensor. The final ``flatten(3)`` implies a 4-D input,
            presumably ``(batch, seq_len, n_heads, head_dim)`` — confirm
            against the caller. The last dimension must be even.
        xk: Key tensor; must be broadcast-compatible with the ``pos_cis``
            view that is built from ``xq``'s shape (same dim 1 and last dim).
        pos_cis: Complex tensor of shape ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the same shapes and dtypes as the
        corresponding inputs.
    """
    def unite_shape(pos_cis, x):
        # Reshape pos_cis so it broadcasts against x: keep dim 1 and the last
        # dim, and set every other dim to 1.
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    # Pair up the last dimension and reinterpret each pair as one complex
    # number (computed in float32 regardless of the input dtype).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    # Complex multiplication performs the rotation; flatten(3) merges the
    # (pair, 2) trailing dims back into the original last dimension.
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int):
    """Repeat every head along dim 2 ``n_rep`` times, keeping copies adjacent.

    Equivalent to ``torch.repeat_interleave(x, n_rep, dim=2)``.

    Args:
        x: 4-D tensor unpacked as ``(bs, slen, n_kv_heads, head_dim)``.
        n_rep: Number of copies of each head. When it is 1 the input tensor
            itself is returned without copying.

    Returns:
        Tensor of shape ``(bs, slen, n_kv_heads * n_rep, head_dim)`` in which
        output heads ``[i * n_rep, (i + 1) * n_rep)`` all equal input head ``i``.
    """
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # The new axis must come AFTER the head axis so that expand()
        # duplicates each head in place; inserting it before the head axis
        # either fails to expand or interleaves different heads.
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )