# 搭建Minimind模型
首先要理清模型结构<br>
![structure](../images/LLM-structure.png)


# 搭建思路
采用自底向上的方式搭建model，理清所有的因素
- tokenizer
- embedding
- AttentionBlock
- ffn
- output


# tokenizer
在此之前已经训练好了tokenizer了，这里我们就开始利用之前训练好的tokenizer，来将原始信息转为input_ids的结构<br>
这里学习到的点：padding来统一输出input_ids的形状。left-padding和right-padding也是常见考察点<br>
注意！tokenizer初始化的时候就定好了padding的方向了，后续无法更改的。

In [1]:
from numpy import pad
from transformers import AutoTokenizer
tokenizer_path="./"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path,padding_side="right") # padding_side可以是"left"或"right"，默认是"right"
#单条信息的情况,batch=1
messages = [[
        {"role": "system", "content": "你是一个优秀的聊天机器人，总是给我正确的回应！"},
        {"role": "user", "content": '你来自哪里？'},
        {"role": "assistant", "content": '我来自地球'}
    ],[
        {"role": "system", "content": "你是一个糟糕的捣乱机器人，总是给我错误的回应！"},
        {"role": "user", "content": '你来自哪里？'},
        {"role": "assistant", "content": '我来自火星'}
    ],
    ]
#多条信息

inputs_r = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,return_tensors="pt",padding="max_length",
    truncation=True,
    max_length=100,
    padding_side="right"# 可以尝试"left"或"right"
    ) #直接使用会导致长度不一致的典型问题，因此需要padding到一致长度
print("inputs_r.shape=",inputs_r.shape)
print("inputs_r",inputs_r)
inputs_l = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True,return_tensors="pt",padding="max_length",
    truncation=True,
    max_length=100,
    padding_side="left"# 可以尝试"left"或"right"
    )
print("inputs_l.shape=",inputs_l.shape)
print("inputs_l",inputs_l)

#由于一开始tokenizer初始化的时候就定好了padding的方向了，所以后续无法更改的。上面也对比出来了

input_ids=inputs_r
print(input_ids.shape)
print(type(input_ids))
# 可知有很多tokenizer输出的tensor形状了



  from .autonotebook import tqdm as notebook_tqdm


inputs_r.shape= torch.Size([2, 100])
inputs_r tensor([[   1,   87,   93,  307,   73,   81,  203,  397,  924, 5235, 3317, 2117,
          265, 2603, 1132, 2599,  703,  472,  997,    2,  203,    1,   89,   87,
         3709,  203,  397, 2722, 3016,  425,    2,  203,    1,   69,   87,   87,
           77,  307, 3924,   88,  203,  301, 2722, 1284,    2,  203,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [   1,   87,   93,  307,   73,   81,  203,  397,  924, 5990, 2391,  328,
          240,  101, 3789, 2117,  265, 2603, 1132, 1395,  264,  703,  472,  997,
            2,  203,    1,   89,   87, 3709,  203,  397, 2722, 3016,  425,    2,
          203,    1,   69,   

# embedding
tokenizer实现了word2vec，而下一步就是将原始的类似one-hot编码转化为向量化的压缩的input_ids

In [2]:
#实现embedding
import torch
from torch import nn
# embedding
class Embedding(nn.Module):
    def __init__(self,vocab_size,embed_dim):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
    def forward(self,input_ids):
        return self.embedding(input_ids)

# 测试embedding
vocab_size = tokenizer.vocab_size
embed_dim = 128  # 嵌入维度
embedding_layer = Embedding(vocab_size, embed_dim)
# input_ids是tokenizer输出的input_ids
res= embedding_layer(input_ids)
print("res.shape=",res.shape)
print("res",res)



res.shape= torch.Size([2, 100, 128])
res tensor([[[ 0.6838, -0.2991,  0.8177,  ...,  0.8232,  1.6614,  2.9230],
         [ 0.1591, -0.4361, -0.7887,  ...,  0.1875, -0.6190,  1.5368],
         [ 0.7740, -0.9560,  0.1643,  ...,  0.1395,  0.2955, -0.0335],
         ...,
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822],
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822],
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822]],

        [[ 0.6838, -0.2991,  0.8177,  ...,  0.8232,  1.6614,  2.9230],
         [ 0.1591, -0.4361, -0.7887,  ...,  0.1875, -0.6190,  1.5368],
         [ 0.7740, -0.9560,  0.1643,  ...,  0.1395,  0.2955, -0.0335],
         ...,
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822],
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822],
         [ 0.8112, -0.2788,  1.8459,  ...,  0.3327,  0.3586,  1.2822]]],
       grad_fn=<EmbeddingBackward0>)


# AttentionBlock
主要分为三个组件
- RMSNorm
- GQA
- FFN
embedding结束后，输入向量压缩为了[2,100,128]的tensor<br>
将这个tensor输入attentionblock，捕捉tensor内部的注意力关系<br>
这里采用的GQA机制，需要实现GQA，还需要实现RoPE编码，从而捕捉tensor内部的位置和时序关系
![img](../images/LLM-structure.png)

## RMSNorm
### 均方根层归一化 (Root Mean Square Layer Normalization, RMSNorm)

RMSNorm 是对 LayerNorm 的一个改进,  没有做 re-center 操作（移除了均值项）, 可以看作 LayerNorm 在均值为零时的特例, 使用平方根均值归一化降低噪声影响。

- **Layer Norm**

$$y = \frac{x-E(x)}{\sqrt{Var(x) + \epsilon}} * \gamma + \beta$$

假设输入张量形状为 (batch_size,  sequence_length,  embedding_dim), 层归一化对 embedding_dim 维度进行归一化操作, 其中,  $\epsilon$ 是一个超参数, 用于防止分母为零导致结果上溢,  $\gamma$,  $\beta$ 均为可学习参数。

- **RMS Norm**

$$a_i=\frac{a_i}{RMS(a) + \epsilon} * \gamma,  \quad where \quad RMS(a) = \sqrt{\frac{1}{n}\sum^n_{i=1}a^2_i}.$$

假设输入张量形状为 (batch_size,  sequence_length,  embedding_dim), RMS Norm 对 embedding_dim 维度进行归一化,其中,  其中,  $\epsilon$ 是一个超参数, 用于防止分母为零导致结果上溢, $\gamma$ 为可学习参数.

不难发现, 当均值为零时, Layer Norm 退化为 RMS Norm. 这是因为 RMS Norm 在 Layer Norm 的基础上舍弃了中心化操作, 仅用缩放进行归一化, 其不改变数据原本的分布, 有利于激活函数输出的稳定.

In [3]:
class RMSNorm(nn.Module):
    def __init__(self,embed_dim,eps=1e-6):
        super(RMSNorm,self).__init__()
        self.embed_dim=embed_dim
        self.eps=eps
        self.weight=nn.Parameter(torch.ones(embed_dim))
    def forward(self,x):
        return x*torch.rsqrt(x.pow(2).mean(-1,keepdim=True)+self.eps).type_as(x)*self.weight
# 测试RMSNorm
embed_dim = 128  # 嵌入维度
rmsnorm_layer = RMSNorm(embed_dim)
input_tensor = torch.randn(2, 100, embed_dim)  # 假设输入的tensor形状为[batch_size, seq_length, embed_dim]
output_tensor = rmsnorm_layer(input_tensor)
print("input_tensor",input_tensor)
print("output_tensor",output_tensor)

input_tensor tensor([[[ 0.0882,  2.0853, -1.8059,  ..., -1.1505,  0.4109, -0.6234],
         [-0.2026,  0.2099,  0.6908,  ..., -0.3157, -1.7592, -0.2903],
         [-1.9698, -0.1924,  0.6977,  ..., -0.2409, -0.1523,  1.0376],
         ...,
         [ 0.4626, -0.1045,  1.1635,  ..., -0.3794, -1.5879, -2.1123],
         [-0.4531, -0.0845,  2.0073,  ..., -0.2041,  1.1379,  0.2551],
         [ 0.8110,  1.2606, -1.3102,  ..., -1.4146, -0.1723, -0.7322]],

        [[-0.2703, -0.2388,  0.5468,  ...,  2.2620,  1.3278, -0.0072],
         [-1.7425,  0.1245, -0.6601,  ..., -0.7695,  0.3383, -0.4106],
         [ 1.7596, -2.1032, -0.0324,  ..., -0.0683, -0.0102, -2.2862],
         ...,
         [-0.6159, -0.4573,  0.8271,  ..., -0.7213, -1.7506,  1.1421],
         [ 0.9572, -0.1877, -1.8484,  ..., -1.0934, -0.7228,  0.2555],
         [ 1.2894,  0.7856, -2.4600,  ..., -0.0700,  0.0387, -0.8012]]])
output_tensor tensor([[[ 0.0804,  1.9000, -1.6454,  ..., -1.0482,  0.3744, -0.5680],
         [-0.1931,

## RoPE
### Rotary Position Embedding, RoPE

旋转位置编码是一种能将相对位置信息集成到 self-attention 中, 进而提升 transformer 架构性能的位置编码方式, 和绝对位置编码相比, RoPE 具有很好的外推性, 是目前的主流位置编码方式.

外推性的解释, 通俗来说就是训练的时候限制了 512 的上下文长度，那么推理时如果面对超过该长度的文本，LLM 可能无法正确处理.

- **绝对位置编码**

绝对位置编码是早期 Transformer 架构采用的绝对位置编码方案，及那个每个位置映射为固定的向量表示.

$$f_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i,i)=\boldsymbol{W}_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i+\boldsymbol{p}_i)$$

其中编码向量 $p_i$ 的计算使用如下公式：

$$\boldsymbol{p}_{i,2t}=\sin\left(k/1000^{2t/d}\right), \boldsymbol{p}_{i,2t+1}=\cos\left(k/1000^{2t/d}\right)$$

正如其名，绝对位置编码只考虑了输入序列中的绝对位置关系，对于 token 之间的相对信息则没有纳入考虑.

- **旋转位置编码**

假定 query 和 key 的内积操作可以被函数 g 表示，该函数 g 的输入是词嵌入向量 $x_m, x_n$ 和它们之间的相对位置 $m-n$:

$$<f_q(x_m ,m), f_k(x_n, n)>=g(x_m, x_n, m, n)$$

旋转位置编码就是找到一个使上式成立的位置编码方式. 

出于认识的目的，我们省略复杂的数学推导，直接看 RoPE 的的结论：

存在这样一个正交矩阵：

$$\boldsymbol{R}_{\Theta,m}^d=\underbrace{\begin{pmatrix}\cos m\theta_0&-\sin m\theta_0&0&0&\cdots&0&0\\\sin m\theta_0&\cos m\theta_0&0&0&\cdots&0&0\\0&0&\cos m\theta_1&-\sin m\theta_1&\cdots&0&0\\0&0&\sin m\theta_1&\cos m\theta_1&\cdots&0&0\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&0&\cdots&\cos m\theta_{d/2-1}&-\sin m\theta_{d/2-1}&-\sin m\theta_{d/2-1}\end{pmatrix}}_{\boldsymbol{W}_m}$$

其中，$\Theta=\left\{\theta_i=10000^{-2(i-1)/d},i\in[1,2,\ldots,d/2]\right\}$

我们可以将 query 和 key 的内积操作转换为与原始向量 $x$ 相关的以下等价形式：

$$
\boldsymbol{q}_m^\mathbf{T}\boldsymbol{k}_n=\left(\boldsymbol{R}_{\Theta,m}^d\boldsymbol{W}_q\boldsymbol{x}_m\right)^\mathbf{T}\left(\boldsymbol{R}_{\Theta,n}^d\boldsymbol{W}_k\boldsymbol{x}_n\right)=\boldsymbol{x}_m^\mathbf{T}\boldsymbol{W}_q\boldsymbol{R}_{\Theta,n-m}^d\boldsymbol{W}_k\boldsymbol{x}_n
$$

其中， $\boldsymbol{R}_{\Theta,n-m}^d=\left(\boldsymbol{R}_{\Theta,m}^d\right)^\mathbf{T}\boldsymbol{R}_{\Theta,n}^d$.

由于 $\boldsymbol{R}_{\Theta,m}^d$ 的稀疏性，直接使用矩阵乘法会浪费算力，因此代码中采用下述方式实现：

$$\boldsymbol{R}_{\Theta,m}^{d}\boldsymbol{x}=\begin{pmatrix}x_{0}\\x_{1}\\x_{2}\\x_{3}\\\vdots\\x_{d-2}\\x_{d-1}\end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_{0}\\\cos m\theta_{0}\\\cos m\theta_{1}\\\cos m\theta_{1}\\\vdots\\\cos m\theta_{d/2-1}\\\cos m\theta_{d/2-1}\end{pmatrix}+\begin{pmatrix}-x_{1}\\x_{0}\\-x_{3}\\x_{2}\\\vdots\\-x_{d-1}\\x_{d-2}\end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_{0}\\\sin m\theta_{0}\\\sin m\theta_{1}\\\sin m\theta_{1}\\\vdots\\\sin m\theta_{d/2-1}\\\sin m\theta_{d/2-1}\end{pmatrix}
$$

简而言之，RoPE就是用绝对编码的形式，表示出相对编码的关系，这样同时具有了绝对编码的简洁和相对编码的位置信息泛化性<br>
此处的ROPE的实现主要参考的是LLama的RoPE实现
[LLAMA实现](https://blog.csdn.net/m0_55846238/article/details/145728695)<br>
对旋转编码理解困难，可以参考[无痛理解RoPE](https://zhuanlan.zhihu.com/p/8306958113)


大概归纳一下，旋转编码主要两步，首先是制作 $m\Theta$ 的旋转角度的表，之后再应用这个表，用于编码qk



### 首先制作$m\Theta$

In [4]:
def precompute_pos_cis(dim,seqlen,theta=1e5):
    #这里//2是因为要把dim分成两半，前半部分用于cos，后半部分用于sin，所以只会用到一个theta
    freqs=1.0/(theta**(torch.arange(0,dim,2)[:dim//2].float()/dim))
    print("freqs.shape=",freqs.shape)
    
    m=torch.arange(seqlen,device=freqs.device)
    print("m.shape=",m.shape)

    freqs=torch.outer(m,freqs).float()
    print("freqs.shape=",freqs.shape)

    pos_cis=torch.polar(torch.ones_like(freqs),freqs)
    print("pos_cis.shape=",pos_cis.shape)
    return pos_cis

# 测试precompute_pos_cis
dim = 128  # 嵌入维度
seqlen = 100  # 序列长度
pos_cis = precompute_pos_cis(dim, seqlen)
print(type(pos_cis))

freqs.shape= torch.Size([64])
m.shape= torch.Size([100])
freqs.shape= torch.Size([100, 64])
pos_cis.shape= torch.Size([100, 64])
<class 'torch.Tensor'>


### 然后将$m\Theta$应用到旋转编码计算中去

In [5]:
def apply_rotary_emb(xq,xk,pos_cis):
    xq_=torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2))
    xk_=torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2))
    #由于view_as_complex,最后一维合并了，就变成了dim//2的形状
    print("xq_.shape=",xq_.shape)
    def unite_shape(pos_cis,x):
        #将pos_cis对齐x的形状
        ndim = x.ndim
        assert 0 <= 1 < ndim
        #x形状一般为(batch_size, seq_len, n_heads, head_dim//2)
        #这里确保freqs_cis与x的seq_len, head_dim//2维度一致, RoPE是对每个头分别进行的
        assert pos_cis.shape == (x.shape[1],  x.shape[-1]),f"pos_cis.shape:({pos_cis.shape}),(x.shape[1],  x.shape[-1])={(x.shape[1],  x.shape[-1])}"
        shape = [d if i == 1 or i == ndim - 1 else 1 for i,  d in enumerate(x.shape)]
        return pos_cis.view(*shape)
    pos_cis = unite_shape(pos_cis, xq_)
    #将pos_cis应用到xq_和xk_上(和输入对齐)
    print("pos_cis shape:", pos_cis.shape)
    print("xq_ shape:", xq_.shape)
    print("xk_ shape:", xk_.shape)
    xq_out=torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out=torch.view_as_real(xk_ * pos_cis).flatten(3)
    return xq_out, xk_out
#测试一下apply_rotary_emb函数
xq = torch.randn(2, 3, 2,8)  # (bs, seqlen, dim)
xk = torch.randn(2, 3, 2,8)  # (bs, seqlen, dim)
pos_cis = precompute_pos_cis(dim=8, seqlen=3, theta=1e5)
print(type(pos_cis))
xq_out, xk_out = apply_rotary_emb(xq, xk, pos_cis)
print("xq shape:", xq.shape)  # 应该是 (bs, seqlen, dim)
print("xk shape:", xk.shape)  # 应该是 (bs, seqlen, dim)

print("xq_out_shape:", xq_out.shape)  # 应该是 (bs, seqlen, dim
print("xk_out_shape:", xk_out.shape)  # 应该是 (bs, seqlen, dim
# print("xq_out:", xq_out)
# print("xk_out:", xk_out)
print("pos_cis:", pos_cis)
print("pos_cis shape:", pos_cis.shape)



freqs.shape= torch.Size([4])
m.shape= torch.Size([3])
freqs.shape= torch.Size([3, 4])
pos_cis.shape= torch.Size([3, 4])
<class 'torch.Tensor'>
xq_.shape= torch.Size([2, 3, 2, 4])
pos_cis shape: torch.Size([1, 3, 1, 4])
xq_ shape: torch.Size([2, 3, 2, 4])
xk_ shape: torch.Size([2, 3, 2, 4])
xq shape: torch.Size([2, 3, 2, 8])
xk shape: torch.Size([2, 3, 2, 8])
xq_out_shape: torch.Size([2, 3, 2, 8])
xk_out_shape: torch.Size([2, 3, 2, 8])
pos_cis: tensor([[ 1.0000+0.0000e+00j,  1.0000+0.0000e+00j,  1.0000+0.0000e+00j,
          1.0000+0.0000e+00j],
        [ 0.5403+8.4147e-01j,  0.9984+5.6204e-02j,  1.0000+3.1623e-03j,
          1.0000+1.7783e-04j],
        [-0.4161+9.0930e-01j,  0.9937+1.1223e-01j,  1.0000+6.3245e-03j,
          1.0000+3.5566e-04j]])
pos_cis shape: torch.Size([3, 4])


# GQA
![img](../images/LLM-structure.png)

### 对齐q与kv的工具函数

In [6]:
import torch
def repeat_kv_heads(x,rep_num):
    """
    重复kv的头部
    :param x: (bs, seqlen, n_heads, head_dim)
    :param rep_num: 重复的次数
    :return: (bs, seqlen, n_heads*rep_num, head_dim)
    """
    if rep_num == 1:
        return x
    bs, seqlen, kv_head_num, head_dim = x.shape
    return (
        x[:,:,:,None,:].expand(bs,seqlen, kv_head_num, rep_num, head_dim).reshape(bs, seqlen, kv_head_num * rep_num, head_dim)
    )

#测试一下
x= torch.randn(2, 3, 4, 8)  # (bs, seqlen, n_heads, head_dim)
rep_num = 2  # 重复次数
x_repeated = repeat_kv_heads(x, rep_num)
print("x:", x)  # 原始张量
print("x_repeated:", x_repeated)  # 重复后的张量
print("x shape:", x.shape)  # 应该是 (bs, seqlen, n_heads, head_dim)
print("x_repeated shape:", x_repeated.shape)  # 应该是 (bs

x: tensor([[[[-0.6247,  0.7589, -0.7877,  0.3288, -1.5248, -0.6347,  0.8084,
            0.2233],
          [ 0.8526,  0.6423, -0.5520, -1.8552,  0.8973, -1.5722, -0.5682,
            0.0765],
          [-0.8539, -0.2294,  0.8187, -0.8471, -0.4265,  2.6147, -1.3418,
           -0.4685],
          [ 0.2785, -1.2338,  1.9119, -0.2724, -0.8014, -0.2021, -0.0670,
            0.8867]],

         [[ 1.1346,  0.4250, -0.3245, -1.0280, -0.4393,  2.6461,  0.2373,
           -0.1676],
          [ 0.7621,  0.3029, -1.9954,  2.1155, -0.8595, -1.0120, -0.7186,
            0.5218],
          [-1.3487, -1.7394, -0.1029, -0.4240, -1.2830,  0.2299,  0.7985,
            2.4696],
          [-0.2083,  0.4519,  0.3335,  1.1183, -0.0557,  1.3877, -0.9095,
           -0.5147]],

         [[ 1.2585,  0.2331, -0.0595, -0.9119,  1.2068, -1.7038,  0.0242,
           -0.0851],
          [ 0.0468, -0.2321, -3.1558, -0.3373,  0.7312, -0.3929, -0.3298,
           -0.6969],
          [-0.2102, -0.2781,  0.6874,  0.73

# GQA

In [9]:
import torch
from torch import dropout, nn 
from torch.nn import functional as F
import math
class GroupQueryAttention(nn.Module):
    def __init__(self,embed_dim,head_num,kv_head_num,dropout=0.1,Flash=False,training=True,max_seqlen=100):
        super(GroupQueryAttention, self).__init__()
        ### 基本参数
        self.embed_dim=embed_dim
        self.head_num=head_num
        self.kv_head_num=kv_head_num
        assert head_num % kv_head_num == 0, "head_num must be divisible by kv_head_num"
        self.head_dim=embed_dim // head_num
        assert self.head_dim* head_num == embed_dim, "embed_dim must be divisible by head_num"
        self.rep_num = head_num // kv_head_num
        self.dropout=dropout
        self.scale = math.sqrt(self.head_dim)  # 缩放因子，通常是head_dim的平方根的倒数
        self.max_seqlen=max_seqlen  # 最大序列长度

        ### 网络结构
        self.q_proj=nn.Linear(self.embed_dim,self.head_dim*self.head_num)
        self.k_proj=nn.Linear(self.embed_dim,self.kv_head_num*self.head_dim)
        self.v_proj=nn.Linear(self.embed_dim,self.kv_head_num*self.head_dim)
        self.o_proj=nn.Linear(self.head_num*self.head_dim,self.embed_dim)
        ### dropout等正则化层设置
        self.attn_dropout=nn.Dropout(dropout)
        self.res_dropout=nn.Dropout(dropout)
        ### 其他参数
        self.Flash = hasattr(F,'scaled_dot_product_attention') and Flash # 是否使用Flash Attention
        self.training=training
        ### 因果掩码初始化
        ### 因果掩码作用的是q*k^T的结果,q=(bs,seqlen,head_num,head_dim),k=(bs,seqlen,kv_head_num*rep_num,head_dim)
        ### q*k^T =(bs,seqlen,head_num,kv_head_num*rep_num)
        mask=torch.full((1,1,self.max_seqlen,self.max_seqlen),float("-1e9"))
        mask=torch.tril(mask,diagonal=0)
        print("mask=",mask)
        print("mask.shape=",mask.shape)
        self.register_buffer("mask",mask)
    
    def forward(self,
                x,
                pos_cis=None,
                past_key_value=None,
                use_cache=False):
        """
        :param x: (bs, seqlen, embed_dim)
        :param pos_cis: (seqlen, head_dim//2) or None
        :param past_key_value: (bs, seqlen, kv_head_num*rep_num, head_dim) or None
        :param use_cache: 是否使用kv_cache
        :return:
        """
        bs, seqlen, embed_dim = x.shape
        ### qkv
        xq=self.q_proj(x).reshape(bs,seqlen,self.head_num,self.head_dim)
        xk=self.k_proj(x).reshape(bs,seqlen,self.kv_head_num,self.head_dim)
        xv=self.v_proj(x).reshape(bs,seqlen,self.kv_head_num,self.head_dim)
        xk= repeat_kv_heads(xk,self.rep_num)
        xv= repeat_kv_heads(xv,self.rep_num)
        ### 位置编码
        if pos_cis is not None:
            xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        else:
            pos_cis= precompute_pos_cis(self.head_dim, seqlen, theta=1e5)
            xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        ### kv_cache，仅推理模型可用
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv= (xk, xv) if use_cache else None
        xq,xk,xv=(
            xq.transpose(1,2),
            xk.transpose(1,2),
            xv.transpose(1,2)
        )
        ### 计算注意力
        dropout_p=self.dropout if self.training else 0.0
        if self.Flash:
            attn_out=F.scaled_dot_product_attention(
                xq,xk,xv,
                attn_mask=None, # 这里没有mask是因为is_causal=True时，Flash Attention会自动加上掩码计算
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            attn_scores=torch.matmul(xq,xk.transpose(-2,-1)) / self.scale
            attn_scores+=self.mask[:,:,:seqlen,:seqlen]
            attn_weights=F.softmax(attn_scores,dim=-1)
            attn_weights=self.attn_dropout(attn_weights)
            attn_out=torch.matmul(attn_weights,xv)
            attn_out=attn_out.transpose(1,2).reshape(bs,seqlen,self.head_num*self.head_dim)
        ### 输出
        attn_out=self.o_proj(attn_out)
        attn_out=self.res_dropout(attn_out)
        return attn_out, past_kv
# 测试GroupQueryAttention
embed_dim = 128  # 嵌入维度
head_num = 8  # 注意力头数
kv_head_num = 4  # kv头数
dropout = 0.1  # dropout率
max_seqlen=100
group_query_attention = GroupQueryAttention(embed_dim, head_num, kv_head_num, dropout, max_seqlen=max_seqlen)
# input_tensor是tokenizer输出的input_ids
input_tensor = torch.randn(2, 100, embed_dim)  # 假设
output_tensor, past_kv = group_query_attention(input_tensor)
print("output_tensor.shape=", output_tensor.shape)  # 输出张量的形状
print("past_kv:", past_kv)  # past_kv的内容


mask= tensor([[[[-1.0000e+09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]]]])
mask.shape= torch.Size([1, 1, 100, 100])
freqs.shape= torch.Size([8])
m.shape= torch.Size([100])
freqs.shape= torch.Size([100, 8])
pos_cis.shape= torch.Size([100, 8])
xq_.shape= torch.Size([2, 100, 8, 8])
pos_cis shape: torch.Size([1, 100, 1, 8])
xq_ shape: torch.Size([2, 100, 8, 8])
xk_ shape: torch.Size([2, 100, 8, 8])
output_tens

### FFN

In [10]:
import torch 
from torch import nn
from torch.nn import functional as F
class FFN(nn.Module):
    def __init__(self,embed_dim,ffn_dim,dropout=0.1):
        super(FFN,self).__init__()
        ### 自有参数
        self.embed_dim=embed_dim
        self.ffn_dim=ffn_dim
        self.dropout=dropout
        ### 网络结构
        self.gate_proj=nn.Linear(embed_dim,ffn_dim)
        self.up_proj=nn.Linear(embed_dim,ffn_dim)
        self.down_proj=nn.Linear(ffn_dim,embed_dim)
        ### dropout
        self.gate_dropout=nn.Dropout(dropout)
    def forward(self,x):
        residual=F.silu(self.gate_proj(x))
        residual=self.gate_dropout(residual)
        x=self.up_proj(x)
        x+= residual
        x=self.down_proj(x)
        return x
# 测试FFN
embed_dim = 128  # 嵌入维度
ffn_dim = 512  # FFN维度
dropout = 0.1  # dropout率
ffn_layer = FFN(embed_dim, ffn_dim, dropout)
# input_tensor是tokenizer输出的input_ids
input_tensor = torch.randn(2, 100, embed_dim)  # 假设
output_tensor = ffn_layer(input_tensor)
print("output_tensor.shape=", output_tensor.shape)  # 输出张量的形状
print("output_tensor", output_tensor)  # 输出张量的内容

output_tensor.shape= torch.Size([2, 100, 128])
output_tensor tensor([[[-3.6558e-01, -5.2952e-02,  3.9966e-01,  ...,  3.5406e-01,
           4.1486e-02, -7.8592e-01],
         [ 1.3671e-01, -2.2398e-01, -1.4500e+00,  ..., -6.2017e-02,
           2.2749e-01,  1.1504e+00],
         [-1.3193e-01,  3.6840e-01,  2.3067e-01,  ...,  7.8578e-02,
           1.7384e-01, -3.6819e-01],
         ...,
         [-2.5162e-01,  3.0838e-01, -9.8149e-02,  ..., -1.9799e-01,
          -4.4678e-01,  5.3567e-01],
         [-1.0654e-01, -7.2246e-02,  1.0375e-01,  ..., -7.3212e-02,
           3.0868e-01,  6.0068e-01],
         [ 2.8243e-01, -1.4833e-01, -2.1640e-01,  ...,  4.5890e-02,
           6.4178e-02,  2.3454e-01]],

        [[-2.7831e-02, -2.7405e-01, -5.7016e-02,  ..., -4.3016e-01,
           2.2789e-01,  2.6693e-01],
         [-2.3732e-01, -2.1335e-01,  3.4846e-04,  ..., -4.3830e-01,
          -1.6731e-01,  3.3049e-01],
         [ 7.7521e-01, -1.0378e-01,  7.2096e-01,  ...,  3.1679e-01,
           2.02

# 搭建block
![img](../images/LLM-structure.png)

In [11]:
from turtle import pos
import torch 
from torch import nn
from torch.nn import functional as F
class MiniMindBlock(nn.Module):
    def __init__(self,layer_id,seqlen,embed_dim,head_num,kv_head_num,ffn_dim):
        super(MiniMindBlock,self).__init__()
        ### 基本参数
        self.layer_id=layer_id #记录layer编号
        self.seqlen=seqlen # 序列长度
        self.embed_dim=embed_dim
        self.head_num=head_num
        self.kv_head_num=kv_head_num
        assert head_num % kv_head_num == 0, "head_num must be divisible by kv_head_num"
        self.head_dim=embed_dim // head_num
        assert self.head_dim* head_num == embed_dim, "embed_dim must be divisible by head_num"
        self.rep_num = head_num // kv_head_num
        self.ffn_dim=ffn_dim

        ### 网络结构
        self.norm1=RMSNorm(embed_dim)
        self.attn=GroupQueryAttention(embed_dim,head_num,kv_head_num)
        self.norm2=RMSNorm(embed_dim)
        self.ffn=FFN(embed_dim,ffn_dim)
        ### 额外的初始化
        pos_cis=precompute_pos_cis(self.head_dim,self.seqlen)
        self.register_buffer("pos_cis",pos_cis)
    def forward(self,x,
                past_key_value=None,
                use_cache=False):
        """
        :param x: (bs, seqlen, embed_dim)
        :param past_key_value: (bs, seqlen, kv_head_num*rep_num, head_dim) or None
        :param use_cache: 是否使用kv_cache
        :return:
        """
        ### 1. norm1
        x=self.norm1(x)
        ### 2. attn
        attn_out, past_kv = self.attn(x,pos_cis=self.pos_cis,past_key_value=past_key_value,use_cache=use_cache)
        ### 3. norm2
        x=x+attn_out
        x=self.norm2(x)
        ### 4. ffn
        x=x+self.ffn(x)
        return x, past_kv
# 测试MiniMindBlock
layer_id = 0  # 层编号
seqlen = 100  # 序列长度
embed_dim = 128  # 嵌入维度
head_num = 8  # 注意力头数
kv_head_num = 4  # kv头数
ffn_dim = 512  # FFN维度
mini_mind_block = MiniMindBlock(layer_id, seqlen, embed_dim, head_num, kv_head_num, ffn_dim)
# input_tensor是tokenizer输出的input_ids
input_tensor = torch.randn(2, 100, embed_dim)  # 假设
output_tensor, past_kv = mini_mind_block(input_tensor)
print("output_tensor.shape=", output_tensor.shape)  # 输出张量的形状
print("past_kv:", past_kv)  # past_kv的内容

mask= tensor([[[[-1.0000e+09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]]]])
mask.shape= torch.Size([1, 1, 100, 100])
freqs.shape= torch.Size([8])
m.shape= torch.Size([100])
freqs.shape= torch.Size([100, 8])
pos_cis.shape= torch.Size([100, 8])
xq_.shape= torch.Size([2, 100, 8, 8])
pos_cis shape: torch.Size([1, 100, 1, 8])
xq_ shape: torch.Size([2, 100, 8, 8])
xk_ shape: torch.Size([2, 100, 8, 8])
output_tens

# Minimind_Dense

In [14]:
import torch
from torch import nn
from torch.nn import functional as F
class MiniMindLM(nn.Module):
    def __init__(self,vocab_size,embed_dim,head_num,kv_head_num,ffn_dim,num_layers,max_seqlen=100):
        super(MiniMindLM,self).__init__()
        ### 基本参数
        self.vocab_size=vocab_size
        self.embed_dim=embed_dim
        self.head_num=head_num
        self.kv_head_num=kv_head_num
        assert head_num % kv_head_num == 0, "head_num must be divisible by kv_head_num"
        self.head_dim=embed_dim // head_num
        assert self.head_dim* head_num == embed_dim, "embed_dim must be divisible by head_num"
        self.rep_num = head_num // kv_head_num
        self.ffn_dim=ffn_dim
        self.num_layers=num_layers
        self.max_seqlen=max_seqlen

        ### 网络结构
        self.embedding=Embedding(vocab_size,embed_dim)
        self.blocks=nn.ModuleList([
            MiniMindBlock(i,self.max_seqlen,embed_dim,head_num,kv_head_num,ffn_dim)
            for i in range(num_layers)
        ])
        self.norm=RMSNorm(embed_dim)
        self.lm_head=nn.Linear(embed_dim,vocab_size)
    def forward(self,x):
        """
        :param x: (bs, seqlen)
        :return: (bs, seqlen, vocab_size)
        """
        bs, seqlen = x.shape
        ### 1. embedding
        x=self.embedding(x)
        ### 2. blocks
        past_kv=None
        for i,block in enumerate(self.blocks):
            x,past_kv=block(x,past_key_value=past_kv,use_cache=False)
        ### 3. norm
        x=self.norm(x)
        ### 4. lm_head
        logits=self.lm_head(x)
        logits=F.softmax(logits, dim=-1)  # 对最后一维进行softmax
        return logits
# 测试MiniMindLM
vocab_size = tokenizer.vocab_size
embed_dim = 128  # 嵌入维度
head_num = 8  # 注意力头数
kv_head_num = 4  # kv头数
ffn_dim = 512  # FFN维度
num_layers = 6  # 层数
max_seqlen = 100  # 最大序列长度
mini_mind_lm = MiniMindLM(vocab_size, embed_dim, head_num,
                            kv_head_num, ffn_dim, num_layers, max_seqlen=max_seqlen)
# input_tensor是tokenizer输出的input_ids
input_tensor = torch.randint(0, vocab_size, (2, 100))  #
print("input_tensor=", input_tensor)  # 假设输入的tensor形状为[batch_size, seq_length]
# 假设输入的tensor形状为[batch_size, seq_length]
output_tensor = mini_mind_lm(input_tensor)
print("output_tensor=", output_tensor)  # 输出张量的内容
print("output_tensor.shape=", output_tensor.shape)  # 输出张量的形状


mask= tensor([[[[-1.0000e+09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]]]])
mask.shape= torch.Size([1, 1, 100, 100])
freqs.shape= torch.Size([8])
m.shape= torch.Size([100])
freqs.shape= torch.Size([100, 8])
pos_cis.shape= torch.Size([100, 8])
mask= tensor([[[[-1.0000e+09,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-1.0000e+09, -1.0000e+09,  0.0000e+00,  ...