In [1]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class GPTConfig:
    """Hyper-parameters for the buddygpt model.

    Attributes:
        n_block: Maximum sequence length; the attention layer sizes its causal
            mask as ``n_block x n_block``.
        n_layer: Number of layers (not consumed by any code in this section).
        n_head: Number of query heads.
        n_embed: Model (embedding) width; must be divisible by ``n_head``.
        n_vocab: Vocabulary size.
        n_kv_head: Number of key/value heads shared by the query heads.
        model_type: Model identifier string.
        pad_token_id: Padding token id, or ``None`` when there is none.
        bos_token_id: Beginning-of-sequence token id, or ``None``.
        eos_token_id: End-of-sequence token id.
    """

    n_block: int
    n_layer: int
    n_head: int
    n_embed: int
    n_vocab: int = 50257
    n_kv_head: int = 2
    model_type: str = 'buddygpt'
    # These three used to be written as `name=value,` -- the trailing comma turned each
    # value into a one-element tuple (e.g. `(50256,)`) and, lacking an annotation, they
    # were not dataclass fields at all. They are now real, overridable fields.
    pad_token_id: Optional[int] = None
    bos_token_id: Optional[int] = None
    eos_token_id: Optional[int] = 50256
    # Deliberately left un-annotated: a class-level constant shared by every instance,
    # not a per-instance dataclass field (a mutable list cannot be a field default).
    keys_to_ignore_at_inference = ["past_key_values"]



In [2]:
# # rope
# def precompute_freqs_cis(dim, max_seq_len=2048, theta=10000.0):
#     freqs = theta ** -(torch.arange(0, dim, 2)[:dim//2].float() / dim)
#     t = torch.arange(max_seq_len)
#     freqs = torch.outer(t, freqs) # m * \theta
#     # freqs = t * freqs
#     freqs = torch.polar(torch.ones_like(freqs), freqs) # cos(m * \theta) + jsin(m * \theta)
#     return 

# # 2. Reshape freqs for broadcasting
# def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
#     if freqs_cis.shape[0] > x.shape[1]:
#         freqs_cis = freqs_cis[:x.shape[1]]
#     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
#     shape = [1 if i != 1 and i != x.ndim - 1 else x.shape[i] for i in range(x.ndim)]
#     return freqs_cis.view(*shape).to(x.device)

# def apply_rotary_emb(q, k, freqs):
#     xq = torch.view_as_complex(q.view(*q.shape[:-1], -1, 2)) # batch, seq_len, n_head, dim//2
#     xk = torch.view_as_complex(k.view(*k.shape[:-1], -1, 2)) # batch, seq_len, n_head, dim//2
    
#     freqs_cis = reshape_for_broadcast(freqs, xq) # freqs_cis.shape = (1,seq_len,1,dim//2)

#     xq_out = torch.view_as_real(xq * freqs_cis).flatten(3) # batch, seq_len, n_head, dim
#     xk_out = torch.view_as_real(xk * freqs_cis).flatten(3) # batch, seq_len, n_head, dim

#     return xq_out.type_as(q), xk_out.type_as(k)
    

In [3]:
# rope = precompute_freqs_cis(dim=64)
# q = torch.randn(2, 4, 12, 64)  # B=2, H=4, T=12, D=64
# k = torch.randn(2, 4, 12, 64)
# xq, xk = apply_rotary_emb(q, k, rope)
# xq.shape, xk.shape

In [4]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) implemented with complex multiplication.

    A table of unit complex numbers ``exp(1j * m * theta_i)`` is precomputed for every
    position ``m < max_seq_len`` and every frequency ``theta_i = theta ** (-2i / dim)``.
    Consecutive pairs of features are then rotated by multiplying with that table.

    Args:
        dim: Feature size the frequencies are derived from; must be even.
        max_seq_len: Number of positions to precompute.
        theta: Base of the geometric frequency progression.
    """

    def __precompute_freqs_cis(self, dim, max_seq_len, theta):
        """Return the complex rotation table of shape ``(max_seq_len, dim // 2)``."""
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        inv_freq = theta ** -(torch.arange(0, dim, 2).float() / dim)
        # Float positions keep torch.outer on a single dtype.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, dim / 2)
        # polar(1, angle) = cos(angle) + 1j * sin(angle)
        return torch.polar(torch.ones_like(angles), angles)

    def __init__(self, dim, max_seq_len=2048, theta=10000.0):
        super().__init__()
        self.dim = dim
        # Kept as a plain attribute rather than a registered buffer: Module.to(dtype)
        # would cast a complex buffer to a real dtype and drop the imaginary part.
        # Device placement is handled lazily in apply_rotary_emb instead.
        self.freqs = self.__precompute_freqs_cis(dim, max_seq_len, theta)

    @staticmethod
    def _rotate(x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` and return it in ``x``'s dtype."""
        # view_as_complex does not accept every reduced-precision dtype, so compute in
        # float32 and cast back afterwards.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * freqs_cis
        return torch.view_as_real(rotated).flatten(-2).type_as(x)

    def apply_rotary_emb(self, q, k=None):
        """Apply the rotation to ``q`` (and ``k`` when given).

        Args:
            q: Tensor laid out as ``(batch, seq_len, n_head, head_dim)`` -- the sequence
                length is read from dim 1 and the rotated features from the last dim.
            k: Optional tensor with the same sequence length and last-dim size as ``q``
                (the head count may differ).

        Returns:
            The rotated ``q``, or the tuple ``(rotated_q, rotated_k)`` when ``k`` is given.

        Raises:
            ValueError: If the sequence or feature size exceeds the precomputed table,
                or the feature size is odd.
        """
        seq_len, dim = q.size(1), q.size(-1)
        if seq_len > self.freqs.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the precomputed maximum {self.freqs.size(0)}"
            )
        if dim % 2 != 0 or dim // 2 > self.freqs.size(1):
            raise ValueError(
                f"feature size {dim} must be even and at most {2 * self.freqs.size(1)}"
            )
        if self.freqs.device != q.device:
            # Move once and cache so later calls on the same device are free.
            self.freqs = self.freqs.to(q.device)
        # (1, seq_len, 1, dim / 2): broadcasts over the batch and head axes.
        freqs_cis = self.freqs[None, :seq_len, None, :dim // 2]
        q_out = self._rotate(q, freqs_cis)
        if k is None:
            return q_out
        return q_out, self._rotate(k, freqs_cis)
                                   
        

In [5]:
# Smoke test: rotating random query/key tensors must leave their shapes unchanged.
# apply_rotary_emb reads the sequence length from dim 1 and the rotated features from
# the last dim, so these tensors are (batch=2, seq_len=12, n_head=4, head_dim=64).
rope = RotaryEmbedding(dim=64)
q, k = torch.randn(2, 12, 4, 64), torch.randn(2, 12, 4, 64)
q1, k1 = rope.apply_rotary_emb(q, k=k)
(q1.shape, k1.shape)

(torch.Size([2, 12, 4, 64]), torch.Size([2, 12, 4, 64]))

In [6]:
class GQA(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    ``n_head`` query heads share ``n_kv_head`` key/value heads; each key/value head is
    repeated ``n_head // n_kv_head`` times so that it lines up with its query group.

    Args:
        config: Object exposing ``n_head``, ``n_kv_head``, ``n_embed`` and ``n_block``.
        use_flash: ``True`` to use ``F.scaled_dot_product_attention``, ``False`` for the
            explicit matmul/softmax path. ``None`` (default) defers to a module-level
            ``FLASH`` flag if one is defined, otherwise uses the fused kernel whenever
            this PyTorch build provides it.

    Raises:
        ValueError: If ``n_embed`` is not divisible by ``n_head`` or ``n_head`` is not
            divisible by ``n_kv_head``.
    """

    def __init__(self, config, use_flash=None):
        super().__init__()
        if config.n_embed % config.n_head != 0:
            raise ValueError(
                f"n_embed ({config.n_embed}) must be divisible by n_head ({config.n_head})"
            )
        if config.n_head % config.n_kv_head != 0:
            raise ValueError(
                f"n_head ({config.n_head}) must be divisible by n_kv_head ({config.n_kv_head})"
            )
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embed = config.n_embed
        self.head_dim = self.n_embed // self.n_head
        # Total width of the key/value projections (all KV heads concatenated).
        self.kv_head_dim = self.head_dim * self.n_kv_head
        self.repeat_factor = self.n_head // self.n_kv_head
        self.use_flash = use_flash
        self.q_proj = nn.Linear(self.n_embed, self.n_embed)
        self.k_proj = nn.Linear(self.n_embed, self.kv_head_dim)
        self.v_proj = nn.Linear(self.n_embed, self.kv_head_dim)
        self.out_proj = nn.Linear(self.n_embed, self.n_embed)
        # RoPE rotates per-head features, so its frequencies must be derived from
        # head_dim (not n_embed) and its table must cover the whole context window.
        self.rope = RotaryEmbedding(self.head_dim, max_seq_len=config.n_block)
        self.register_buffer('tril', torch.tril(torch.ones(config.n_block, config.n_block)).view(1,1,config.n_block, config.n_block))

    def _flash_enabled(self):
        """Resolve which attention implementation to use for this call."""
        if self.use_flash is not None:
            return bool(self.use_flash)
        # Backward compatible with notebooks that toggle a module-level FLASH flag.
        return bool(globals().get('FLASH', hasattr(F, 'scaled_dot_product_attention')))

    def forward(self, x):
        """Run causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, n_embed)`` with ``T <= n_block``.

        Returns:
            Tensor of shape ``(B, T, n_embed)``.

        Raises:
            ValueError: If ``T`` exceeds the size of the causal mask.
        """
        B, T, _ = x.shape
        if T > self.tril.size(-1):
            raise ValueError(
                f"sequence length {T} exceeds the configured block size {self.tril.size(-1)}"
            )
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)     # B, T, n_head, head_dim
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim)  # B, T, n_kv_head, head_dim
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim)  # B, T, n_kv_head, head_dim

        # RoPE expects the (B, T, heads, head_dim) layout, so rotate before transposing.
        q, k = self.rope.apply_rotary_emb(q, k)

        xq = q.transpose(1, 2)  # B, n_head, T, head_dim
        # Repeat each KV head so every query head has a matching key/value head.
        xk = k.transpose(1, 2).repeat_interleave(self.repeat_factor, dim=1)  # B, n_head, T, head_dim
        xv = v.transpose(1, 2).repeat_interleave(self.repeat_factor, dim=1)  # B, n_head, T, head_dim

        if self._flash_enabled():
            # Must receive the rotated, head-major, repeated tensors -- not the raw
            # projections, which would attend across heads and skip RoPE entirely.
            o_attn = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        else:
            # Scale the logits by 1/sqrt(head_dim) *before* the softmax; this matches
            # the default scaling of scaled_dot_product_attention.
            scores = torch.matmul(xq, xk.transpose(-2, -1)) * (self.head_dim ** -0.5)
            scores = scores.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
            o_attn = F.softmax(scores, dim=-1) @ xv  # B, n_head, T, head_dim
        o_attn = o_attn.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out_proj(o_attn)
        

In [7]:
# Smoke test for GQA: the attention layer must map (batch, seq_len, n_embed) to the
# same shape. GQA.forward consults this module-level flag; a truthy value selects the
# F.scaled_dot_product_attention branch.
FLASH = 1
config = GPTConfig(
    n_block=512,
    n_layer=2,
    n_head=8,
    n_embed=64,
    n_kv_head=2,
)
gqa = GQA(config)
x = torch.randn(2, 12, 64)  # (batch=2, seq_len=12, n_embed=64)
gqa(x).size()

torch.Size([2, 12, 64])

In [11]:
# Split the last (feature) axis of `x` into two equal halves and show their shapes.
x0, x1 = torch.chunk(x, chunks=2, dim=-1)
(x0.shape, x1.shape)

(torch.Size([2, 12, 32]), torch.Size([2, 12, 32]))

In [12]:
# Gate one half with the SiLU of the other (element-wise) and show the resulting shape.
torch.mul(F.silu(x0), x1).shape

torch.Size([2, 12, 32])