In [1]:
from transformers import Qwen2ForCausalLM
import torch.nn.functional as F
import torch
from torch import nn, Tensor
from dataclasses import dataclass
import numpy as np

In [2]:
# see: https://charent.github.io/p/llama2模型结构方面的改进/
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    The statistics are computed in float32 and the result is cast back to the
    dtype of the input.

    Args:
        hidden_size: Size of the last dimension; also the length of ``weight``.
        eps: Constant added to the mean square before taking the reciprocal
            square root.
    """

    def __init__(self, hidden_size: int, eps: float=1e-6):
        super().__init__()

        # Per-feature gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        # Attribute name (sic) is preserved as-is since it is publicly visible.
        self.eposilon = eps

    def forward(self, hidden_states: Tensor):
        """Normalize ``hidden_states`` by its RMS along the last dimension.

        Returns a tensor with the same shape and dtype as the input.
        """
        input_dtype = hidden_states.dtype
        x_fp32 = hidden_states.to(dtype=torch.float32)

        # Mean of squares over the feature dimension, kept for broadcasting.
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)

        # Note: this is rsqrt, not sqrt; rsqrt(x) = 1 / sqrt(x).
        inv_rms = torch.rsqrt(variance + self.eposilon)
        normalized = x_fp32 * inv_rms
        scaled = self.weight * normalized

        return scaled.to(dtype=input_dtype)


In [3]:
def rotate_half(x: Tensor):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (``a`` being the first ``n // 2`` entries
    and ``b`` the remainder) the result is ``[-b, a]``.
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = torch.split(x, [split_at, width - split_at], dim=-1)
    return torch.cat([-back, front], dim=-1)


def apply_rotay_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, position_ids: Tensor, unsqueeze_dim: int=1):
    """Apply rotary position embedding to the query and key tensors.

    Each input is rotated as ``x * cos + rotate_half(x) * sin``.

    Args:
        q: Query tensor whose last dimension is the rotary dimension.
        k: Key tensor whose last dimension is the rotary dimension.
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            so that it broadcasts against ``q`` and ``k``.
        sin: Sine table, treated the same way as ``cos``.
        position_ids: Accepted for interface compatibility; it is not used,
            the tables are applied in the order they are given.
        unsqueeze_dim: Axis at which ``cos`` and ``sin`` are unsqueezed.

    Returns:
        ``(q_embedding, k_embedding)``, the rotated query and key.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embedding = (q * cos) + (rotate_half(q) * sin)
    # The key must be rotated from `k` itself; mixing in `q * cos` here would
    # leak query content into the keys and broadcast them to the query shape.
    k_embedding = (k * cos) + (rotate_half(k) * sin)

    return q_embedding, k_embedding

class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_position_embedding: int=1024, base:int = 1_0000, device: torch.device=None):
        super().__init__()
        
        self.hidden_size = dim
        self.max_position_embedding = max_position_embedding
        self.base = base

        inv_freq = 1.0 / (self.base ** (
            torch.arange(0, dim, 2, dtype=torch.int64).float().to(device=device) / dim
        ))

        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.max_seq_len_cached = None
    
    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        self.max_seq_len_cached = seq_len

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        new_freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        embedding = torch.cat((new_freqs, new_freqs), dim=-1)
        self.register_buffer('cos_cached', embedding.sin().to(dtype=dtype), persistent=False)
        self.register_buffer('sin_cached', embedding.cos().to(dtype=dtype), persistent=False)
    
    @torch.no_grad()
    def forward(self, x: Tensor, seq_len: Tensor|int):
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[: seq_len].to(dtype=x.dtype),
            self.sin_cached[: seq_len].to(dtype=x.dtype),
        )

In [4]:
@dataclass
class MLAConfig:
    hidden_size: int = 1024
    num_heads: int = 8
    assert hidden_size % num_heads == 0

    max_position_embeddings: int = 1024
    rope_theta: float = 10_0000.0
    attention_dropout: float = 0.1

    q_lora_rank: int = 128
    
    qk_rope_head_dim: int = 8
    qk_nope_head_dim: int = 24
    q_head_dim: int = qk_rope_head_dim + qk_nope_head_dim

    assert  hidden_size % q_head_dim == 0

    kv_lora_rank: int = 32
    v_head_dim: int = 16
    
    assert hidden_size % v_head_dim == 0
    attention_bias: bool=False


In [5]:
class MLA(nn.Module):
    """Multi-head latent attention with low-rank query and key/value projections.

    Queries are produced by a down-projection, RMSNorm and an up-projection.
    Keys and values come from a shared low-rank latent; one extra rotary key
    slice per token is shared by all heads. Each query/key head is the
    concatenation of a non-rotary part (``qk_nope_head_dim``) and a rotary
    part (``qk_rope_head_dim``).

    Args:
        config: ``MLAConfig`` providing all dimensions and the dropout rate.
    """

    def __init__(self, config: MLAConfig):
        super().__init__()

        self.config = config
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        # q
        self.q_down_proj = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.q_lora_rank,
            bias=config.attention_bias,
        )
        self.q_down_layernorm = RMSNorm(config.q_lora_rank)

        self.q_up_proj = nn.Linear(
            in_features=config.q_lora_rank,
            # out_features is not hidden_size: each head is later split into
            # the rope and nope parts.
            out_features=config.num_heads * self.q_head_dim,
            bias=config.attention_bias,
        )
        # NOTE(review): not used in forward(); kept so the module's parameter
        # names stay unchanged.
        self.q_up_layernorm = RMSNorm(config.q_lora_rank)

        # kv: the down-projection also emits the shared rotary key slice.
        self.kv_down_proj = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.kv_lora_rank + config.qk_rope_head_dim,
            bias=False,
        )

        self.kv_down_layernorm = RMSNorm(config.kv_lora_rank)

        self.kv_up_proj = nn.Linear(
            in_features=config.kv_lora_rank,
            out_features=config.num_heads * (config.qk_nope_head_dim + config.v_head_dim),
            bias=False,
        )

        self.out_proj = nn.Linear(
            in_features=config.num_heads * config.v_head_dim,
            out_features=config.hidden_size,
            bias=config.attention_bias,
        )

        self.rotary_emb = RotaryEmbedding(
            dim=config.qk_rope_head_dim,
            max_position_embedding=config.max_position_embeddings,
            base=config.rope_theta,
        )

        # Scale by the full query/key head width (nope + rope), since the dot
        # product runs over the concatenated vector.
        self.atten_factor = self.q_head_dim ** -0.5

    def forward(self,
                hidden_states: Tensor,
                attention_mask: Tensor=None,
                position_ids: Tensor=None,
            ):
        """Compute attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``[bs, q_len, hidden_size]``.
            attention_mask: Optional tensor broadcastable to
                ``[bs, num_heads, q_len, q_len]``; scores where it equals 0
                are set to ``-inf`` before the softmax.
            position_ids: Forwarded to ``apply_rotay_pos_emb``.

        Returns:
            ``(atten_output, atten_weights)`` with shapes
            ``[bs, q_len, hidden_size]`` and ``[bs, num_heads, q_len, q_len]``.
            ``atten_weights`` are the post-dropout probabilities.
        """
        bs, q_len, _ = hidden_states.shape
        config = self.config

        # 1. q
        q: Tensor = self.q_down_proj(hidden_states)
        q = self.q_down_layernorm(q) # [bs, q_len, q_lora_rank]
        q = self.q_up_proj(q) # [bs, q_len, num_heads * (qk_nope_head_dim + qk_rope_head_dim)]
        q = q.reshape(bs, q_len, config.num_heads, self.q_head_dim)

        q_nope, q_pe = torch.split(
            q,
            [config.qk_nope_head_dim, config.qk_rope_head_dim],
            dim=-1,
        )

        # 2. kv
        downed_kv: Tensor = self.kv_down_proj(hidden_states)
        downed_kv, k_pe = torch.split(
            downed_kv,
            [config.kv_lora_rank, config.qk_rope_head_dim],
            dim=-1,
        )
        # A single rotary key slice per token, shared by every head.
        k_pe = k_pe.reshape(bs, q_len, 1, config.qk_rope_head_dim)

        # [bs, q_len, kv_lora_rank]
        kv: Tensor = self.kv_down_layernorm(downed_kv)

        ## TODO kv cache here

        kv = self.kv_up_proj(kv) # [bs, q_len, num_heads * (qk_nope_head_dim + v_head_dim)]
        kv = kv.reshape(bs, q_len, config.num_heads, config.qk_nope_head_dim + config.v_head_dim)

        k_nope, value_states = torch.split(
            kv,
            [config.qk_nope_head_dim, config.v_head_dim],
            dim=-1,
        )

        # 3. Apply rotary position embedding to the rope slices of q and k.
        # value_states is only passed for its dtype / device.
        cos, sin = self.rotary_emb(x=value_states, seq_len=q_len)
        q_pe, k_pe = apply_rotay_pos_emb(q=q_pe, k=k_pe, cos=cos, sin=sin, position_ids=position_ids)

        # 4. Assemble the full heads. torch.cat keeps the dtype of the
        # projections, so the layer also works in half / bfloat16.
        query_states = torch.cat((q_nope, q_pe), dim=-1)
        key_states = torch.cat(
            (k_nope, k_pe.expand(-1, -1, config.num_heads, -1)),
            dim=-1,
        )

        # 5. Attention scores: move heads in front of the sequence axis.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        atten_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        atten_weights = atten_weights * self.atten_factor

        if attention_mask is not None:
            atten_weights = torch.masked_fill(
                atten_weights,
                attention_mask == 0,
                -torch.inf
            )

        # softmax in float32, then back to the working dtype
        atten_weights = F.softmax(atten_weights, dim=-1, dtype=torch.float32).to(dtype=query_states.dtype)
        atten_weights = F.dropout(atten_weights, p=config.attention_dropout, training=self.training)

        # output: [bs, num_heads, q_len, v_head_dim]
        atten_output = torch.matmul(atten_weights, value_states)
        # Put the sequence axis back in front of the heads before flattening;
        # otherwise head and position entries get interleaved.
        atten_output = atten_output.transpose(1, 2).reshape(bs, q_len, -1)

        atten_output = self.out_proj(atten_output)

        return atten_output, atten_weights


def test():
    """Smoke-test MLA: run one forward pass and check the output shapes."""
    config = MLAConfig()
    model = MLA(config)

    seq_len = 1000
    batch_size = 2
    x = torch.rand((batch_size, seq_len, config.hidden_size))
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)

    # No gradients are needed for a shape check.
    with torch.no_grad():
        atten_output, atten_weights = model(x, position_ids=position_ids)

    # Fail loudly instead of relying on someone reading the printed shapes.
    assert atten_output.shape == (batch_size, seq_len, config.hidden_size)
    assert atten_weights.shape == (batch_size, config.num_heads, seq_len, seq_len)

    print(atten_output.shape)
    print(atten_weights.shape)

test()

torch.Size([2, 1000, 1024])
torch.Size([2, 8, 1000, 1000])
