# 1-Model

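Today's mainstream large language models fall roughly into two architectures: dense models and mixture-of-experts (MoE) models. In this part, we study both through the source code of the MiniMind model family.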

## MiniMind Dense Model

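The MiniMind models are designed on the basis of Llama 3.1 and follow the classic decoder-only Transformer architecture. Their main features are:
1. Pre-Norm normalization, with RMSNorm as the normalization function.
2. The SwiGLU activation function.
3. Rotary Position Embedding (RoPE).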

The author provides a visualization of the MiniMind model structure:

![](../images/LLM-structure.png)

The previous section covered the Tokenizer; this section focuses on the concrete implementation of the LLM architecture.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Root Mean Square Layer Normalization (RMSNorm)

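RMSNorm is a simplification of LayerNorm that skips the re-centering step (the mean term is removed), so it can be seen as the special case of LayerNorm in which the mean is zero. Normalizing by the root mean square alone is intended to reduce the influence of noise.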

- Layer Norm

$$y = \frac{x-E(x)}{\sqrt{Var(x) + \epsilon}} * \gamma + \beta$$

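Assume the input tensor has shape (batch_size, sequence_length, embedding_dim). Layer normalization normalizes over the embedding_dim dimension. Here $\epsilon$ is a small constant (a hyperparameter) that prevents division by zero, and $\gamma$ and $\beta$ are learnable parameters.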

- RMS Norm

$$\bar{a}_i=\frac{a_i}{\text{RMS}(\boldsymbol{a})}\,g_i,\quad \text{where } \text{RMS}(\boldsymbol{a}) = \sqrt{\frac{1}{n}\sum^n_{i=1}a^2_i}.$$

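Assume again that the input tensor has shape (batch_size, sequence_length, embedding_dim). RMS normalization also normalizes over the embedding_dim dimension, and $g_i$ is a learnable gain that rescales the normalized result.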

In [2]:
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * (x.float() * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)
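
As a quick sanity check of the formula, the sketch below re-implements RMS normalization as a standalone function. The name `rms_norm` and the value `eps=1e-6` are illustrative assumptions, not part of MiniMind. It checks two properties: every normalized vector has a root mean square close to 1, and on zero-mean input the result matches LayerNorm without its affine parameters.

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide by the root mean square over the last dimension, then apply the gain.
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x / rms * weight

torch.manual_seed(0)
x = torch.randn(2, 4, 8)               # (batch_size, sequence_length, embedding_dim)
gain = torch.ones(8)                   # the gain g, initialized to ones

y = rms_norm(x, gain)
row_rms = y.pow(2).mean(-1).sqrt()     # RMS of every normalized vector

# With the mean removed beforehand, RMSNorm and LayerNorm (no affine) coincide.
x_centered = x - x.mean(-1, keepdim=True)
matches_layer_norm = torch.allclose(
    rms_norm(x_centered, gain), F.layer_norm(x_centered, (8,), eps=1e-6), atol=1e-5
)
print(row_rms, matches_layer_norm)
```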

### Rotary Position Embedding (RoPE)

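Rotary position embedding integrates relative position information into self-attention, which improves the performance of the Transformer architecture. Compared with absolute position encoding, RoPE extrapolates well and is currently the mainstream position encoding.

Extrapolation refers to the following problem: if the context length is limited to 512 during training, the LLM may fail to handle text longer than that at inference time.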

- Absolute position encoding

$$f_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i,i):=\boldsymbol{W}_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i+\boldsymbol{p}_i)$$

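where the encoding vector $\boldsymbol{p}_i$ is computed with sinusoidal functions:

$$\boldsymbol{p}_{i,2t}=\sin\left(i/10000^{2t/d}\right), \boldsymbol{p}_{i,2t+1}=\cos\left(i/10000^{2t/d}\right)$$

As the name suggests, absolute position encoding only uses the absolute position of each token in the input sequence; the relative position between tokens is not taken into account.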
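
A minimal sketch of sinusoidal absolute position encoding, assuming the usual base of 10000 and an even embedding dimension `d`. The function name `sinusoidal_position_encoding` is hypothetical and does not appear in MiniMind, which uses RoPE instead.

```python
import torch

def sinusoidal_position_encoding(seq_len: int, d: int, base: float = 10000.0) -> torch.Tensor:
    # One row per position i, one column per embedding dimension.
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
    two_t = torch.arange(0, d, 2, dtype=torch.float32)                    # 2t = 0, 2, ..., d-2
    angles = positions / base ** (two_t / d)                              # (seq_len, d/2)
    pe = torch.zeros(seq_len, d)
    pe[:, 0::2] = torch.sin(angles)   # even dimensions 2t
    pe[:, 1::2] = torch.cos(angles)   # odd dimensions 2t+1
    return pe

pe = sinusoidal_position_encoding(seq_len=16, d=64)
print(pe.shape)
```

Each row of `pe` would be added to the token embedding at that position before the Q/K/V projections, which is the $\boldsymbol{x}_i+\boldsymbol{p}_i$ term in the formula.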

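- Rotary position embedding

Suppose the inner product of a query and a key can be expressed by a function $g$ (somewhat like a kernel function) whose inputs are the word embeddings $\boldsymbol{x}_m, \boldsymbol{x}_n$ and their relative position $m-n$:

$$\langle f_q(\boldsymbol{x}_m ,m), f_k(\boldsymbol{x}_n, n)\rangle=g(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n)$$

Rotary position embedding is a position encoding that makes this equation hold: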

$$f_{\{q,k\}}\left(\boldsymbol{x}_m,m\right)=\boldsymbol{R}_{\Theta,m}^d\boldsymbol{W}_{\{q,k\}}\boldsymbol{x}_m$$

$$\boldsymbol{R}_{\Theta,m}^d=\begin{pmatrix}\cos m\theta_0&-\sin m\theta_0&0&0&\cdots&0&0\\\sin m\theta_0&\cos m\theta_0&0&0&\cdots&0&0\\0&0&\cos m\theta_1&-\sin m\theta_1&\cdots&0&0\\0&0&\sin m\theta_1&\cos m\theta_1&\cdots&0&0\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&0&\cdots&\cos m\theta_{d/2-1}&-\sin m\theta_{d/2-1}\\0&0&0&0&\cdots&\sin m\theta_{d/2-1}&\cos m\theta_{d/2-1}\end{pmatrix}$$

$$\Theta=\left\{\theta_i=10000^{-2i/d},\ i\in\{0,1,\ldots,d/2-1\}\right\}$$

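Because $\boldsymbol{R}_{\Theta,m}^d$ is sparse, applying it through a direct matrix multiplication wastes computation, so it is implemented in the following element-wise form instead: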

$$\boldsymbol{R}_{\Theta,m}^{d}\boldsymbol{x}=\begin{pmatrix}x_{0}\\x_{1}\\x_{2}\\x_{3}\\\vdots\\x_{d-2}\\x_{d-1}\end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_{0}\\\cos m\theta_{0}\\\cos m\theta_{1}\\\cos m\theta_{1}\\\vdots\\\cos m\theta_{d/2-1}\\\cos m\theta_{d/2-1}\end{pmatrix}+\begin{pmatrix}-x_{1}\\x_{0}\\-x_{3}\\x_{2}\\\vdots\\-x_{d-1}\\x_{d-2}\end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_{0}\\\sin m\theta_{0}\\\sin m\theta_{1}\\\sin m\theta_{1}\\\vdots\\\sin m\theta_{d/2-1}\\\sin m\theta_{d/2-1}\end{pmatrix}
$$
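
The element-wise form can be sketched directly with real-valued tensors. The helper `rope_rotate` below is hypothetical (MiniMind uses an equivalent complex-number formulation) and assumes a base of 10000 and an even dimension. The check at the end illustrates the relative-position property: the inner product of a rotated query and key depends only on the offset $m-n$.

```python
import torch

def rope_rotate(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate the last dimension of x as if the vector sat at position m."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)   # theta_0 .. theta_{d/2-1}
    cos = torch.cos(m * theta).repeat_interleave(2)   # (cos m*theta_0, cos m*theta_0, cos m*theta_1, ...)
    sin = torch.sin(m * theta).repeat_interleave(2)
    # Build (-x_1, x_0, -x_3, x_2, ...) from the (even, odd) pairs of x.
    x_swapped = torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)
    return x * cos + x_swapped * sin

torch.manual_seed(0)
q, k = torch.randn(64), torch.randn(64)

# The same offset m - n = 3 at two different absolute positions.
score_a = torch.dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_b = torch.dot(rope_rotate(q, 13), rope_rotate(k, 10))
print(torch.allclose(score_a, score_b, atol=1e-3))
```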

In [3]:
# compute pe in complex form
def precompute_pos_cis(dim: int, end: int=int(32 * 1024), theta: float=1e6):
    """
    compute pe frequency in complex form.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)
    return pos_cis

def apply_rotary_emb(xq, xk, pos_cis):
    def unite_shape(pos_cis, x):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return pos_cis.view(*shape)

    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    print('original shape xq: {}'.format(xq.shape)) # shape of xq before the complex view
    print('ajusted shape xq: {}'.format(xq_.shape)) # shape of xq after viewing it as complex numbers
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_)
    print('ajusted shape pos_cis： {}'.format(pos_cis.shape)) # shape of the rotation factors after reshaping for broadcasting
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [4]:
xq, xk = torch.randn((2, 16, 4, 64)), torch.randn((2, 16, 4, 64)) # (batch_size, sequence_length, num_heads, head_dim)
pos_cis = precompute_pos_cis(64, 16) # rotation factors of the rotary position embedding, in complex form
pos_cis.shape, pos_cis[0, 0]

(torch.Size([16, 32]), tensor(1.+0.j))

In [5]:
xq_rope, xk_rope = apply_rotary_emb(xq, xk, pos_cis)

original shape xq: torch.Size([2, 16, 4, 64])
ajusted shape xq: torch.Size([2, 16, 4, 32])
ajusted shape pos_cis： torch.Size([1, 16, 1, 32])


In [6]:
xq_rope.shape, xk_rope.shape

(torch.Size([2, 16, 4, 64]), torch.Size([2, 16, 4, 64]))

### Attention

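Attention is a key mechanism in the Transformer architecture: it captures dependencies across long sequences and models the importance of each position through attention scores.

In the MiniMind model, the Attention Block involves several important mechanisms.

1. GQA
2. KV Cache
3. Pre-RMSNorm & RoPE
4. SwiGLU
5. FFN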

- GQA

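Grouped Query Attention (GQA) is an extension of multi-head self-attention that groups the query heads, offering a flexible trade-off between computational efficiency and model expressiveness.

Concretely, GQA divides the query heads into G groups, and the heads within each group share one key head and one value head.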

![img](./demo/gqa.png)
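
The grouping can be made concrete with a small sketch. The head counts below (8 query heads, 2 key/value heads) are illustrative assumptions, and `repeat_interleave` is used here only as a compact way of showing which key head each query head reads.

```python
import torch

bs, slen, head_dim = 1, 3, 4
n_heads, n_kv_heads = 8, 2            # 8 query heads share 2 key/value heads
n_rep = n_heads // n_kv_heads         # number of query heads per group

torch.manual_seed(0)
k = torch.randn(bs, slen, n_kv_heads, head_dim)

# Repeat every key head n_rep times so that it lines up with the query heads.
k_expanded = k.repeat_interleave(n_rep, dim=2)      # (bs, slen, n_heads, head_dim)

# Query head h reads the key head of its group, h // n_rep.
group_of_head = [h // n_rep for h in range(n_heads)]
print(k_expanded.shape, group_of_head)
```

Only `n_kv_heads` key and value heads have to be projected and cached, so the K/V projections and the cache shrink by a factor of `n_rep` compared with standard multi-head attention.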

- KV Cache

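KV Cache speeds up autoregressive inference by trading GPU memory for computation. Because of the causal mask, earlier tokens never attend to later ones, so the K and V already computed for earlier tokens do not change and can be cached and reused instead of being recomputed at every step.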
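
A minimal single-head sketch of the idea, using random tensors as stand-ins for the projected queries, keys and values (all names and sizes are illustrative). Decoding one token at a time against a growing cache gives the same result as full causal attention over the whole sequence.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 5, 8
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Reference: full causal attention over all T positions at once.
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
full = F.softmax(q @ k.T / math.sqrt(d) + mask, dim=-1) @ v

# Incremental decoding: append the new K/V to the cache and attend with one query.
k_cache, v_cache = torch.empty(0, d), torch.empty(0, d)
outputs = []
for t in range(T):
    k_cache = torch.cat([k_cache, k[t:t + 1]])
    v_cache = torch.cat([v_cache, v[t:t + 1]])
    weights = F.softmax(q[t:t + 1] @ k_cache.T / math.sqrt(d), dim=-1)
    outputs.append(weights @ v_cache)
incremental = torch.cat(outputs)

print(torch.allclose(full, incremental, atol=1e-5))
```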

- SwiGLU

$$\text{SwiGLU}(x,W,V,b,c)=\text{Swish}_1(xW+b)\otimes(xV+c)$$
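
The formula translates almost literally into code. The function `swiglu` below is a hypothetical standalone sketch with plain weight matrices; it relies on the fact that $\text{Swish}_1$ is what PyTorch calls SiLU, that is $z \cdot \sigma(z)$.

```python
import torch
import torch.nn.functional as F

def swiglu(x, W, V, b=None, c=None):
    # Gate branch goes through Swish_1 (SiLU); value branch stays linear.
    gate = x @ W if b is None else x @ W + b
    value = x @ V if c is None else x @ V + c
    return F.silu(gate) * value

torch.manual_seed(0)
x = torch.randn(4, 8)
W, V = torch.randn(8, 16), torch.randn(8, 16)
out = swiglu(x, W, V)
print(out.shape)
```

Note that the non-linearity is applied to one branch only, before the element-wise product.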

In [7]:
from demo.LMConfig import LMConfig
from typing import Any, Optional, Tuple, List
import torch.nn as nn
import math

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads # falls back to n_heads when args.n_kv_heads is None
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self,
               x: torch.Tensor,
               pos_cis: torch.Tensor,
               past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
               use_cache=False):
        bsz, seq_len, _ = x.shape
        ############## Forward QKV & RoPE ##############
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, pos_cis) # RoPE rotates queries and keys only; values stay unchanged
        ################### KV Cache ###################
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None
        xq, xk, xv = (
            xq.transpose(1, 2),
            repeat_kv(xk, self.n_rep).transpose(1, 2),
            repeat_kv(xv, self.n_rep).transpose(1, 2)
        )
        ######### Scaled Dot-Product Attention #########
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv
        ################################################
        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv

In [8]:
attn = Attention(LMConfig())
x = torch.randn((4, 16, 512))
print(x.shape)
pos_cis = precompute_pos_cis(64, 16)
print(pos_cis.shape)
output, past_kv = attn(x, pos_cis=pos_cis, use_cache=True)
print(output.shape, past_kv[0].shape, past_kv[1].shape)

torch.Size([4, 16, 512])
torch.Size([16, 32])
original shape xq: torch.Size([4, 16, 8, 64])
ajusted shape xq: torch.Size([4, 16, 8, 32])
ajusted shape pos_cis： torch.Size([1, 16, 1, 32])
torch.Size([4, 16, 512]) torch.Size([4, 16, 2, 64]) torch.Size([4, 16, 2, 64])
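
The two branches of the forward pass compute the same thing when dropout is off. The sketch below checks this on random tensors, assuming PyTorch 2.0 or later so that `F.scaled_dot_product_attention` is available; the shapes are illustrative.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch_size, n_heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))

# Manual path: scaled scores, additive causal mask, softmax, weighted sum.
mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + mask
manual = F.softmax(scores, dim=-1) @ v

# Fused path: the causal mask is built internally when is_causal=True.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```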


### FeedForward Network

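The feed-forward network takes the output of the attention layer and transforms it further, using linear projections with a gated non-linearity in between, to capture more complex features and representations.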

In [9]:
class FeedForward(nn.Module):
    def __init__(self, config: LMConfig):
        super().__init__()
        if config.hidden_dim is None:
            hidden_dim = 4 * config.dim
            hidden_dim = int(2 * hidden_dim / 3)
            config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
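        # SwiGLU: only the gate branch w1(x) goes through SiLU, then it is multiplied by the linear branch w3(x)
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))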

In [10]:
ffn = FeedForward(LMConfig())
x = torch.randn((4, 16, 512))
output = ffn(x)
output.shape

torch.Size([4, 16, 512])
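
The hidden size is not stored in the config by default; it is derived from `dim`. The sketch below repeats that arithmetic in a hypothetical helper `ffn_hidden_dim`, assuming `multiple_of = 64` (the actual default lives in `LMConfig` and is not shown in this notebook).

```python
def ffn_hidden_dim(dim: int, multiple_of: int) -> int:
    # Start from 4 * dim, shrink to 2/3 of it (three matrices instead of two),
    # then round up to the next multiple of `multiple_of`.
    hidden_dim = int(2 * (4 * dim) / 3)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

hidden = ffn_hidden_dim(512, 64)
n_params = 3 * 512 * hidden      # w1, w2 and w3 have no bias
print(hidden, n_params)          # 1408 2162688
```

The 2/3 factor keeps the parameter count of the three-matrix SwiGLU network close to that of a classic two-matrix feed-forward network with hidden size `4 * dim`.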

### MiniMind Block

At this point the Attention layer and the FeedForward layer are both built, so all the required components are in place. We can now assemble a MiniMind Block.

In [11]:
class MiniMindBlock(nn.Module):
    def __init__(self, layer_id: int, config: LMConfig):
        super().__init__()
        self.n_heads = config.n_heads
        self.dim = config.dim
        self.head_dim = config.dim // config.n_heads # dimension of each attention head
        self.attention = Attention(config)

        self.layer_id = layer_id
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.feed_forward = FeedForward(config)

    def forward(self, x, pos_cis, past_key_value=None, use_cache=False):
        h_attn, past_kv = self.attention(
            self.attention_norm(x), # pre-normed x
            pos_cis,
            past_key_value=past_key_value,
            use_cache=use_cache
        )
        h = x + h_attn # residual connection
        out = h + self.feed_forward(self.ffn_norm(h)) # feed forward + residual connection
        return out, past_kv

In [12]:
miniblock = MiniMindBlock(1, LMConfig())
x = torch.randn((4, 16, 512))
pos_cis = precompute_pos_cis(64, 16)
out, past_kv = miniblock(x, pos_cis, use_cache=True)
out.shape, past_kv[0].shape, past_kv[1].shape

original shape xq: torch.Size([4, 16, 8, 64])
ajusted shape xq: torch.Size([4, 16, 8, 32])
ajusted shape pos_cis： torch.Size([1, 16, 1, 32])


(torch.Size([4, 16, 512]),
 torch.Size([4, 16, 2, 64]),
 torch.Size([4, 16, 2, 64]))
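
The block follows the Pre-Norm pattern: normalize, apply the sub-layer, then add the untouched input back. The sketch below isolates that pattern with a hypothetical wrapper `PreNormResidual`; `nn.LayerNorm` and `nn.Linear` are stand-ins for RMSNorm and for the attention and feed-forward sub-layers.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # stand-in; MiniMind uses RMSNorm here
        self.sublayer = sublayer

    def forward(self, x):
        # Normalize first, run the sub-layer, then add the untouched input back.
        return x + self.sublayer(self.norm(x))

dim = 16
block = nn.Sequential(
    PreNormResidual(dim, nn.Linear(dim, dim)),   # plays the role of attention
    PreNormResidual(dim, nn.Linear(dim, dim)),   # plays the role of the feed-forward network
)
x = torch.randn(2, 5, dim)
out = block(x)
print(out.shape)
```

Because the residual path never passes through a normalization layer, the input reaches the output unchanged whenever a sub-layer outputs zeros.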

### MiniMindLM

With the MiniMind Block as the basic building block, we can now do the final assembly of the LLM!

In [13]:
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

class MOEFeedForward():
    """Place holder"""
    pass

class MiniMindLM(PreTrainedModel):
    """Notice: To simplify notebook run, set n_layers to 2"""
    config_class = LMConfig

    def __init__(self, params: LMConfig = None):
        self.params = params or LMConfig()
        super().__init__(self.params)
        params = self.params # fall back to the default config when no params are passed
        self.vocab_size, self.n_layers = params.vocab_size, params.n_layers
        # mapping: vocabulary index -> embedding dimension
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])
        self.norm = RMSNorm(params.dim, eps = params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.tok_embeddings.weight = self.output.weight
        self.register_buffer(
            "pos_cis",
            precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),
            persistent=False
        )
        self.OUT = CausalLMOutputWithPast()

    def forward(self,
               input_ids: Optional[torch.Tensor] = None,
               past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
               use_cache: bool = False,
               **args):
        past_key_values = past_key_values or [None] * len(self.layers) # save history KV
        start_pos = args.get('start_pos', 0)
        h = self.dropout(self.tok_embeddings(input_ids))
        pos_cis = self.pos_cis[start_pos: start_pos + input_ids.size(1)]
        past_kvs = []
        for l, layer in enumerate(self.layers):
            h, past_kv = layer(
                h, pos_cis,
                past_key_value=past_key_values[l],
                use_cache=use_cache
            )
            past_kvs.append(past_kv)
        logits = self.output(self.norm(h))
        aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward)) # MoE auxiliary loss
        self.OUT.__setitem__('logits', logits)
        self.OUT.__setitem__('aux_loss', aux_loss)
        self.OUT.__setitem__('past_key_value', past_kvs)
        return self.OUT

    @torch.inference_mode()
    def generate(self, input_ids, eos_token_id=2, max_new_tokens=512, temperature=0.75, top_p=0.90,
                stream=False, rp=1, use_cache=True, pad_token_id=0, **args):
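        # streaming generation (returns a generator, so the caller controls the output)
        if stream:
            return self._stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
        # direct generation (returns the full sequences at once)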
        generated = []
        for i in range(input_ids.size(0)):
            non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
            out = self._stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args)
            tokens_list = [tokens[:, -1:] for tokens in out]
            gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
            full_sequence = torch.cat([non_pad, gen], dim=-1)
            generated.append(full_sequence)
        max_length = max(seq.size(1) for seq in generated)
        generated = [
            torch.cat([seq, torch.full((1, max_length - seq.size(1)), pad_token_id, dtype=seq.dtype, device=seq.device)],dim=-1) 
            for seq in generated
        ]
        return torch.cat(generated, dim=0)

    def _stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
        start, first_seq, past_kvs = input_ids.shape[1], True, None
        while input_ids.shape[1] < max_new_tokens - 1:
            if first_seq or not use_cache: # first step, or no KV Cache: feed the whole token id sequence every time
                out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache, **args), False
            else: # later steps with KV Cache: feed only the last token id and reuse the cache to speed up inference
                out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
                           start_pos=input_ids.shape[1] - 1, **args)
            # out.logits has shape (batch_size, seq_len, vocab_size); keep the logits of the last position.
            # forward() stores the cache under the key 'past_key_value', so it is read back under that name.
            logits, past_kvs = out.logits[:, -1, :], out.past_key_value
            logits[:, list(set(input_ids.tolist()[0]))] /= rp # penalize tokens that already appeared, to reduce repetition
            logits /= (temperature + 1e-9) # temperature scaling, controls the diversity of the output
            if top_p is not None and top_p < 1.0: # top-p (nucleus) sampling
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = False
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = -float('Inf')
            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) # sample from the remaining tokens
            input_ids = torch.cat((input_ids, input_ids_next), dim=1)
            yield input_ids[:, start:]
            if input_ids_next.item() == eos_token_id:
                break
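
The top-p step is easier to follow on a toy distribution. The sketch below pulls the same filtering logic into a hypothetical helper `top_p_filter` and applies it to four tokens with probabilities 0.5, 0.3, 0.15 and 0.05 and `top_p = 0.7`. The smallest prefix whose cumulative probability exceeds 0.7 is the first two tokens, so the last two are masked out.

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    # Shift right so that the token crossing the threshold is still kept.
    remove[:, 1:] = remove[:, :-1].clone()
    remove[:, 0] = False
    # Map the mask from sorted order back to vocabulary order.
    mask = remove.scatter(1, sorted_indices, remove)
    return logits.masked_fill(mask, float("-inf"))

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
filtered = top_p_filter(probs.log(), top_p=0.7)
kept = (~torch.isinf(filtered[0])).tolist()
print(kept)   # [True, True, False, False]

# Sampling now only ever returns token 0 or token 1.
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)
```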

In [14]:
MiniMind = MiniMindLM(LMConfig())
input_ids = torch.Tensor([1, 3, 5, 7]).long().reshape(1, 4)
OUT = MiniMind(input_ids, use_cache=True)
OUT.logits.shape, OUT.aux_loss, len(OUT.past_key_value)

original shape xq: torch.Size([1, 4, 8, 64])
ajusted shape xq: torch.Size([1, 4, 8, 32])
ajusted shape pos_cis： torch.Size([1, 4, 1, 32])


(torch.Size([1, 4, 6400]), 0, 1)
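
One detail of the assembly is worth isolating: the token embedding and the output projection share a single weight matrix (weight tying). A minimal sketch with toy sizes, where the vocabulary size and dimension are illustrative assumptions:

```python
import torch
import torch.nn as nn

vocab_size, dim = 10, 4
tok_embeddings = nn.Embedding(vocab_size, dim)       # weight shape: (vocab_size, dim)
output = nn.Linear(dim, vocab_size, bias=False)      # weight shape: (vocab_size, dim)
tok_embeddings.weight = output.weight                # both modules now hold the same Parameter

# parameters() yields a shared Parameter only once, so it is counted a single time.
model = nn.ModuleDict({"tok_embeddings": tok_embeddings, "output": output})
n_params = sum(p.numel() for p in model.parameters())
print(tok_embeddings.weight is output.weight, n_params)
```

For a small model with a comparatively large vocabulary, this removes one full `vocab_size x dim` matrix from the parameter count.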

In [15]:
# let MiniMind generate output from the four input tokens we designed
out = MiniMind.generate(input_ids, max_new_tokens=8, use_cache=True)
print(out.shape, out)

original shape xq: torch.Size([1, 4, 8, 64])
ajusted shape xq: torch.Size([1, 4, 8, 32])
ajusted shape pos_cis： torch.Size([1, 4, 1, 32])
original shape xq: torch.Size([1, 1, 8, 64])
ajusted shape xq: torch.Size([1, 1, 8, 32])
ajusted shape pos_cis： torch.Size([1, 1, 1, 32])
original shape xq: torch.Size([1, 1, 8, 64])
ajusted shape xq: torch.Size([1, 1, 8, 32])
ajusted shape pos_cis： torch.Size([1, 1, 1, 32])
torch.Size([1, 7]) tensor([[   1,    3,    5,    7, 4867, 4289, 6268]])


## MiniMind MoE Model

@TODO