In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

![rope](https://ar5iv.labs.arxiv.org/html/2104.09864/assets/x1.png)

In [3]:
def precompute_freqs_cis(dim, end, theta=10000.0):
    """Build the table of unit complex rotations used by rotary embeddings.

    Args:
        dim: Feature size being rotated. One frequency is produced per pair of
            features, i.e. ``dim // 2`` frequencies.
        end: Number of positions; rows correspond to positions ``0 .. end - 1``.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[m, i]`` is
        ``cos(m * w_i) + j * sin(m * w_i)`` with ``w_i = theta ** (-2 * i / dim)``.
    """
    # Exponents 0/dim, 2/dim, 4/dim, ... (one per feature pair).
    pair_exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = theta ** -pair_exponents

    positions = torch.arange(end)
    angles = torch.outer(positions, inv_freq)  # angles[m, i] = m * w_i

    # Magnitude 1 everywhere, so each entry is a pure rotation e^{j * angle}.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

In [9]:
def reshape_for_broadcast(freqs, x):
    """Return a view of ``freqs`` that broadcasts against ``x``.

    Axis 1 and the last axis of ``x`` are kept; every other axis becomes a
    singleton. For a 4-D ``x`` this is ``(1, x.shape[1], 1, x.shape[-1])``;
    other ranks get the matching number of singleton axes, so the result can
    be multiplied with ``x`` directly.

    Args:
        freqs: Tensor holding exactly ``x.shape[1] * x.shape[-1]`` elements.
        x: Tensor with at least two axes whose layout the view must match.

    Returns:
        A view of ``freqs`` with ``x.ndim`` axes.

    Raises:
        IndexError: If ``x`` has fewer than two axes.
        RuntimeError: If ``freqs`` cannot be viewed with the target shape.
    """
    axis1_size, last_size = x.shape[1], x.shape[-1]
    # Start from all-singleton axes and restore only the two that freqs indexes.
    target_shape = [1] * x.ndim
    target_shape[1] = axis1_size
    target_shape[-1] = last_size
    return freqs.view(*target_shape)

def _pairs_as_complex(x):
    """View adjacent pairs along the last axis of ``x`` as complex numbers."""
    # torch.view_as_complex rejects bfloat16 and maps float16 to complex32,
    # so half-precision inputs are rotated in float32 instead.
    if x.dtype in (torch.float16, torch.bfloat16):
        x = x.float()
    return torch.view_as_complex(x.view(*x.shape[:-1], -1, 2))


def apply_rotary_emb(q, k, freqs):
    """Rotate query and key feature pairs by per-position complex phases.

    Each consecutive pair ``(x[2i], x[2i+1])`` on the last axis is treated as
    the complex number ``x[2i] + j * x[2i+1]`` and multiplied by the matching
    entry of ``freqs``.

    Args:
        q: 4-D tensor laid out as (batch, seq_len, n_head, dim), ``dim`` even.
        k: 4-D tensor with the same seq_len and dim as ``q``.
        freqs: Complex tensor with ``seq_len * dim // 2`` elements, e.g. of
            shape ``(seq_len, dim // 2)``.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` with the shapes and dtypes of ``q``
        and ``k``. float16/bfloat16 inputs are computed in float32 and cast
        back on return.
    """
    xq = _pairs_as_complex(q)  # batch, seq_len, n_head, dim//2
    xk = _pairs_as_complex(k)  # batch, seq_len, n_head, dim//2
    # Singleton batch and head axes let one table serve every batch row and head.
    freqs_cis = freqs.view(1, xq.shape[1], 1, xq.shape[-1])  # 1, seq_len, 1, dim//2

    # view_as_real appends a (real, imag) axis; flatten(3) re-interleaves it
    # back into the original feature axis.
    xq_out = torch.view_as_real(xq * freqs_cis).flatten(3)  # batch, seq_len, n_head, dim
    xk_out = torch.view_as_real(xk * freqs_cis).flatten(3)  # batch, seq_len, n_head, dim

    return xq_out.type_as(q), xk_out.type_as(k)

In [23]:
# Scratch check: a real-valued tensor and a complex rotation table whose shapes
# line up for broadcasting.
q = torch.randn(2, 10, 2, 5)
# precompute_freqs_cis(10, 10) yields 10 positions x 5 complex entries; the view
# adds singleton axes at positions 0 and 2.
freqs = precompute_freqs_cis(10, 10).view(1, 10, 1, 5)
# Disabled experiment: multiplying the real tensor by the complex table.
# print(q, freqs)
# (q * freqs)
# Observed: a real value times (1+0j) comes back as a complex value, e.g.
# 9.7774e-01 * (1.0000+0.0000e+00j) == 9.7774e-01+0.0000e+00j

In [11]:
# Sanity check: for a 4-D input of shape (2, 10, 2, 5) the helper should return
# a (1, 10, 1, 5) view. NOTE: this rebinds the notebook-global `freqs`.
freqs = precompute_freqs_cis(10, 10)
reshape_for_broadcast(freqs, torch.randn(2,10,2,5)).shape

torch.Size([1, 10, 1, 5])

In [6]:
# `.ndim` reports the number of axes of a tensor (3 for this shape).
torch.randn(2,3,4).ndim

3

![GQA](https://wdndev.github.io/llm_interview_note/02.%E5%A4%A7%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B%E6%9E%B6%E6%9E%84/llama%202%E4%BB%A3%E7%A0%81%E8%AF%A6%E8%A7%A3/image/image_XJgG9to7qe.png)

In [40]:
def repeat_kv(x, n_rep):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        x: 4-D tensor laid out as (batch, seq_len, n_kv_heads, head_dim).
        n_rep: Number of copies of each head. ``1`` returns ``x`` itself.

    Returns:
        Tensor of shape ``(batch, seq_len, n_kv_heads * n_rep, head_dim)`` in
        which the copies of one head sit next to each other.
    """
    batch, seq_len, kv_heads, head_dim = x.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, stretch it to n_rep
        # without copying, then fold it into the head axis.
        tiled = x[:, :, :, None, :].expand(batch, seq_len, kv_heads, n_rep, head_dim)
        return tiled.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
    return x

In [42]:
# Sanity check: 2 KV heads repeated twice should grow the head axis from 2 to 4.
x = torch.randn(2,3,2,5)
repeat_kv(x, 2).shape

torch.Size([2, 3, 4, 5])

In [19]:
# Inverse frequencies for dim=100 in negative-exponent form: 10000 ** -(2i / 100).
10000.0** -(torch.arange(0, 100, 2).float() / 100)

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.7542e-03, 2.2909e-03, 1.9055e-03, 1.5849e-03,
        1.3183e-03, 1.0965e-03, 9.1201e-04, 7.5858e-04, 6.3096e-04, 5.2481e-04,
        4.3652e-04, 3.6308e-04, 3.0200e-04, 2.5119e-04, 2.0893e-04, 1.7378e-04,
        1.4454e-04, 1.2023e-04])

In [14]:
# The same inverse frequencies in reciprocal form: 1 / 10000 ** (2i / 100).
# NOTE: this rebinds the notebook-global `freqs` to a real-valued 1-D tensor.
freqs = 1.0 / (10000.0 ** (torch.arange(0, 100, 2)[: (100 // 2)].float() / 100))
freqs

tensor([1.0000e+00, 8.3176e-01, 6.9183e-01, 5.7544e-01, 4.7863e-01, 3.9811e-01,
        3.3113e-01, 2.7542e-01, 2.2909e-01, 1.9055e-01, 1.5849e-01, 1.3183e-01,
        1.0965e-01, 9.1201e-02, 7.5858e-02, 6.3096e-02, 5.2481e-02, 4.3652e-02,
        3.6308e-02, 3.0200e-02, 2.5119e-02, 2.0893e-02, 1.7378e-02, 1.4454e-02,
        1.2023e-02, 1.0000e-02, 8.3176e-03, 6.9183e-03, 5.7544e-03, 4.7863e-03,
        3.9811e-03, 3.3113e-03, 2.7542e-03, 2.2909e-03, 1.9055e-03, 1.5849e-03,
        1.3183e-03, 1.0965e-03, 9.1201e-04, 7.5858e-04, 6.3096e-04, 5.2481e-04,
        4.3652e-04, 3.6308e-04, 3.0200e-04, 2.5119e-04, 2.0893e-04, 1.7378e-04,
        1.4454e-04, 1.2023e-04])