### 全局配置

In [1]:
from dataclasses import dataclass
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Global configuration class
@dataclass
class ModelArgs:
    """Hyper-parameters for the LLaMA-style Transformer and its KV cache.

    ``vocab_size`` defaults to -1 as a "not set" sentinel and must be assigned
    a real vocabulary size before a model is built. ``device`` defaults to
    ``None`` (unset).
    """

    dim: int = 4096  # embedding dimension (LLaMA uses 4096)
    n_layers: int = 32
    n_heads: int = 32  # number of query (Q) heads
    # Number of key/value heads for grouped-query attention; None presumably
    # means "same as n_heads" — TODO confirm against the attention code.
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5

    # KV-cache sizing
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Optional because the default is None (device not specified).
    device: Optional[str] = None

### 总体模型结构
<img src="imgs/arch.PNG" alt="model arch" style="width:30%; height:auto;" />

In [2]:
class Transformer(nn.Module):
    """LLaMA-style model for single-token, KV-cached inference.

    Pipeline: token embedding -> ``args.n_layers`` ``EncoderBlock`` layers ->
    ``RMSNorm`` -> bias-free linear projection to vocabulary logits.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.vocab_size != -1, "未设置词表大小"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # Token embedding table: vocab_size rows of width args.dim.
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        # Stack of n_layers blocks, each built from the same args.
        self.layers = nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        # Projects hidden states to per-token vocabulary logits.
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # RoPE frequencies for the per-head dimension (dim // n_heads), covering
        # max_seq_len * 2 positions, indexed by position along the first axis.
        # NOTE(review): stored as a plain attribute, not a registered buffer, so
        # if it is a tensor, model.to(...) will not move it; it is created on
        # args.device instead.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Run one decoding step and return logits.

        KV-cache inference only: each call must process exactly one new token
        per sequence.

        Args:
            tokens: (batch, seq_len) token ids; seq_len must be 1.
            start_pos: position of this token in the sequence; selects the
                RoPE frequencies and is passed to every layer.

        Returns:
            Logits cast to float32, computed from the final hidden states
            (softmax is left to the loss / sampler).

        Raises:
            AssertionError: if seq_len != 1.
            ValueError: if start_pos is negative or the requested positions
                exceed the precomputed RoPE table.
        """
        # (B, seq_len); unpacking also enforces a 2-D input.
        _, seq_len = tokens.shape
        assert seq_len == 1, "KV缓存,Q仅为每次更新的一个token 一次处理一个token!"

        # Guard the RoPE slice: a negative start_pos would silently wrap to rows
        # counted from the end, and one past the table would give an empty slice.
        n_positions = len(self.freqs_complex)
        if start_pos < 0 or start_pos + seq_len > n_positions:
            raise ValueError(
                f"start_pos={start_pos} with seq_len={seq_len} is outside the "
                f"precomputed RoPE range [0, {n_positions})"
            )

        # (B, seq_len) -> (B, seq_len, dim)
        h = self.tok_embeddings(tokens)

        # RoPE frequencies for the positions being processed.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h)

        # Final projection to vocabulary logits, upcast to float32.
        output = self.output(h).float()

        # Softmax is applied in the loss / sampling step.
        return output

