### 综述区别
在模型结构层面，llama3 与 llama2 的主要区别在于：所有规模的模型都使用了 GQA（Grouped-Query Attention，分组查询注意力），而 llama2 中只有较大的模型（34B/70B）使用 GQA。

In [5]:
import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

### 模型参数定义
其中 `@dataclass` 是 `Python` 中的一个装饰器：在类定义前加上它，`Python` 会自动为这个类生成一些常用的特殊方法，比如 `__init__()`、`__repr__()`、`__eq__()` 等。
- `__init__()`：不用再手写一个带一堆参数的构造函数，再逐个写 `self.xx = xx` 赋值；
- `__repr__()`：打印该类的实例时，会显示它有哪些字段及其取值；
- `__eq__()`：可以直接用 `==` 判断两个实例的各个字段是否全部相同。

In [27]:
@dataclass
class ModelArgs:
    """Hyper-parameters for the Transformer model in this notebook.

    NOTE(review): with the defaults, ``dim`` (4096) is not divisible by
    ``n_heads`` (6). Attention uses ``head_dim = dim // n_heads`` (= 682), so
    the attention projections work with 6 * 682 = 4092 features, not 4096.
    """

    dim: int = 4096  # model / embedding width
    n_layers: int = 6  # number of TransformerBlock layers
    n_heads: int = 6  # number of query heads
    # Number of query heads that share one key/value head: Attention computes
    # kv_heads = n_heads // n_group, and repeat_kv repeats each K/V head n_group
    # times. Despite the name, this is the repeat factor, not the number of groups.
    # Passing None would fail in Attention (integer division by None).
    n_group: Optional[int] = 3
    vocab_size: int = 4096  # embedding / output-head vocabulary size
    hidden_dim: Optional[int] = None  # FFN hidden width; None -> derived from dim (see FeedForward)
    multiple_of: int = 256  # the derived FFN hidden width is rounded up to a multiple of this (see FeedForward)
    norm_eps: float = 1e-5  # epsilon used by every RMSNorm
    max_seq_len: int = 2048  # sizes the causal mask and the precomputed RoPE tables
    dropout: float = 0.0  # dropout probability used throughout the model

### RMSNorm（均方根归一化），在 Qwen-blog 中已讲解。

In [7]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last axis with a learnable gain.

    Computes ``x / sqrt(mean(x**2, -1) + eps) * weight``. The normalization
    itself runs in float32 and is cast back to the input dtype before the
    per-feature ``weight`` (initialized to ones) is applied.
    """

    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

### RoPE（旋转位置编码）相关，在 Qwen-blog 中已讲解。

In [8]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the RoPE cosine/sine tables.

    Args:
        dim: per-head feature size; one frequency is produced per pair of
            features, i.e. ``dim // 2`` frequencies.
        end: number of positions (0 .. end-1) to tabulate.
        theta: RoPE base; frequency k is ``theta ** (-2k / dim)``.

    Returns:
        ``(cos, sin)``, each a float tensor of shape ``(end, dim // 2)`` holding
        the cosine/sine of ``position * frequency``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View a ``(x.shape[1], x.shape[-1])`` table so it broadcasts against ``x``.

    Axis 1 and the last axis of ``x`` keep their sizes; every other axis
    becomes 1. ``x`` must have at least two dimensions. The table's shape
    is checked with ``assert``.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last_axis = ndim - 1
    target_shape = []
    for axis, size in enumerate(x.shape):
        target_shape.append(size if axis in (1, last_axis) else 1)
    return freqs_cis.view(target_shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding (RoPE) to queries and keys.

    The last dimension of ``xq``/``xk`` is read as ``head_dim // 2``
    interleaved (real, imaginary) pairs. Each pair is rotated by the angle
    whose cosine/sine is stored in ``freqs_cos``/``freqs_sin``.

    The tables must have shape ``(xq.shape[1], head_dim // 2)``;
    ``reshape_for_broadcast`` asserts this. Axis 1 of the inputs is
    therefore the position axis. The tables are reshaped against ``xq`` and
    also applied to ``xk``, so ``xk`` must broadcast with them as well,
    e.g. a different head count in a ``(bs, seq, heads, head_dim)`` layout.

    Returns:
        Rotated ``(xq, xk)`` with the same shapes and dtypes as the inputs.
        Inputs of any rank >= 3 are supported.
    """
    # Split the last dim into (..., head_dim // 2, 2) and separate each pair.
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # Reshape the tables so they broadcast over the non-position axes.
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # Rotation = complex multiplication by (cos + i*sin), on real/imag parts.
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave the pairs and merge the last two dims back into head_dim.
    # flatten(-2) (instead of a hard-coded start_dim=3) keeps this correct for
    # inputs of any rank; for 4-D inputs it is identical to flatten(3).
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)


In [57]:
# Input is a 4-D K/V tensor; n_rep is how many times each K/V head is repeated
# (in this notebook, the group size n_group).
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis (axis 2).

    Input shape ``(bs, seq_len, n_kv_heads, head_dim)``; output shape
    ``(bs, seq_len, n_kv_heads * n_rep, head_dim)``. Output head ``k`` is a
    copy of input head ``k // n_rep``. With ``n_rep == 1`` the input object
    is returned unchanged.
    """
    batch, seq_len, kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        # No sharing needed: this is ordinary multi-head attention.
        return hidden_states
    # (b, s, kv, d) -> (b, s, kv, 1, d) -> broadcast view (b, s, kv, n_rep, d)
    expanded = hidden_states.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
    # Merge (kv, n_rep) into one head axis; copies of one head stay adjacent.
    return expanded.reshape(batch, seq_len, kv_heads * n_rep, head_dim)

### Attention 的实现与 Qwen-blog 中基本一致。
### 这里去掉了 flash-attn，改为手动实现注意力的完整计算流程，更便于初学者理解；欢迎大佬贡献 flash-attn 的详细教程！

In [58]:
class Attention(nn.Module):
    """Causal self-attention with grouped-query attention (GQA).

    Queries use ``n_heads`` heads. Keys/values use ``n_heads // n_group``
    heads, and each K/V head is shared by ``n_group`` query heads via
    ``repeat_kv``. Attention is computed explicitly (no flash-attn), with an
    additive upper-triangular ``-inf`` causal mask of size
    ``max_seq_len x max_seq_len``, registered as a buffer.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.group = args.n_group
        self.heads = args.n_heads
        # n_group is the number of query heads sharing one K/V head, so it must
        # divide n_heads exactly. Otherwise repeat_kv would produce a different
        # number of K/V heads than there are query heads, and the score matmul
        # in forward() would fail with an opaque shape error.
        assert args.n_heads % args.n_group == 0, (
            f"n_heads ({args.n_heads}) must be divisible by n_group ({args.n_group})"
        )
        self.kv_heads = args.n_heads // args.n_group
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout  # stored; forward() uses the Dropout modules above
        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Attend over ``x`` of shape ``(bsz, seqlen, dim)``; returns ``(bsz, seqlen, dim)``.

        ``freqs_cos``/``freqs_sin`` are the RoPE tables for the first
        ``seqlen`` positions, shape ``(seqlen, head_dim // 2)``.
        """
        bsz, seqlen, _ = x.shape

        # Project to Q/K/V and split into heads: (bs, seqlen, heads, head_dim).
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.kv_heads, self.head_dim)

        # Rotary position embedding on queries and keys.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # GQA: repeat each K/V head so K/V have as many heads as Q.
        xk = repeat_kv(xk, self.group)  # (bs, seqlen, heads, head_dim)
        xv = repeat_kv(xv, self.group)  # (bs, seqlen, heads, head_dim)

        # Move heads next to batch: (bs, heads, seqlen, head_dim).
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # Explicit scaled dot-product attention (no flash-attn).
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, 'mask')
        scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, heads, seqlen, seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)  # (bs, heads, seqlen, head_dim)

        # Back to (bs, seqlen, heads * head_dim), concatenating the heads.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # Output projection, then residual dropout.
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output



### FFN 网络本质上是一个 MLP，这里采用 SwiGLU 门控结构：`w2(silu(w1(x)) * w3(x))`，与 Qwen-blog 中的实现一致。

In [59]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    When ``hidden_dim`` is None, the hidden width is ``int(2 * 4 * dim / 3)``
    rounded up to the next multiple of ``multiple_of``.
    """

    def __init__(self, dim: int, hidden_dim: Optional[int], multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            # Two thirds of the classic 4*dim, so the three SwiGLU matrices
            # cost roughly the same as a two-matrix 4*dim MLP.
            base = int(2 * (4 * dim) / 3)
            hidden_dim = multiple_of * ((base + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))

### Block 将上述模块组合在一起，构成完整的 decoder layer（解码器层）。

In [60]:
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer.

    ``h = x + Attention(RMSNorm(x))`` followed by ``out = h + FFN(RMSNorm(h))``.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        # Invoke submodules through __call__ (not .forward) so that any
        # registered nn.Module forward hooks / pre-hooks actually run.
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

In [61]:
class Transformer(nn.Module):
    """LLaMA-style decoder-only language model.

    Pipeline: token embedding -> dropout -> ``n_layers`` TransformerBlocks ->
    RMSNorm -> linear output head. The output head's weight is tied to the
    embedding table.

    ``forward(tokens, targets=None)``:
      * with ``targets``: returns logits for every position, shape
        ``(bsz, seqlen, vocab_size)``. The cross-entropy loss (target id -1
        is ignored) is stored in ``self.last_loss``.
      * without ``targets``: returns logits for the last position only, shape
        ``(bsz, vocab_size)``, and sets ``self.last_loss`` to None.
    """

    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)  # weight shape (vocab, dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)  # weight also (vocab, dim): logits = x @ W^T

        # Weight tying: the embedding and output (unembedding) layers share one matrix.
        self.tok_embeddings.weight = self.output.weight  # reference: https://paperswithcode.com/method/weight-tying

        # Precompute RoPE cos/sin tables of shape (max_seq_len, head_dim // 2).
        # persistent=False keeps them out of the state_dict.
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # Init all weights: N(0, 0.02) for Linear/Embedding weights, zero biases.
        self.apply(self._init_weights)
        # GPT-2 style scaled init for the projections that write into the
        # residual stream: attention output (wo) and FFN output (w2, which maps
        # silu(w1(x)) * w3(x) back to dim). Each layer adds two residual
        # branches, hence the sqrt(2 * n_layers) factor.
        for pn, p in self.named_parameters():
            if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Loss of the last forward call; only set when forward() receives targets.
        self.last_loss = None

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        # RoPE tables for the first seqlen positions.
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # Training: logits for every position plus the cross-entropy loss.
            # reshape (not view) so non-contiguous targets are accepted too.
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)
        else:
            # Inference: only the last position is needed to predict the next token.
            # Indexing with the list [-1] keeps a length-1 time dim for the
            # projection; reshape then drops it, so logits are (bsz, vocab_size).
            logits = self.output(h[:, [-1], :]).reshape(_bsz, -1)
            self.last_loss = None

        return logits

### 运行测试一下！

In [62]:
# Instantiate the model with the default hyper-parameters from ModelArgs.
conf = ModelArgs()
model = Transformer(conf)

In [63]:
# Smoke test: a random batch of 10 sequences x 50 token ids drawn from [10, 100).
# No targets are passed, so the model returns logits for the last position only.
inputs = torch.randint(low=10, high=100, size=(10,50))
output = model(inputs)
output  # notebook cell: displays the logits tensor

tensor([[ 7.0483e-01, -9.6776e-01,  9.7280e-01,  ..., -1.2867e-01,
          1.4969e+00,  4.8647e-01],
        [-5.8376e-01,  1.7032e+00, -9.8937e-01,  ...,  1.1173e+00,
         -2.8729e-05, -2.1118e-02],
        [ 2.5072e+00, -1.5984e+00, -1.7395e+00,  ...,  1.0265e+00,
         -6.7715e-01, -2.3375e-01],
        ...,
        [ 3.8787e-01, -5.7305e-01,  1.7864e+00,  ...,  3.1860e-02,
         -1.0664e-01,  2.7426e+00],
        [-2.3487e+00,  6.1362e-02, -1.9033e-01,  ..., -1.4094e-01,
         -1.8798e-01,  2.4475e-01],
        [-5.5272e-01,  6.0121e-01, -3.0820e-01,  ..., -6.4737e-01,
         -9.3688e-01, -4.5705e-01]], grad_fn=<ViewBackward0>)

In [64]:
# Inference-mode logits have shape (batch, vocab_size).
output.shape

torch.Size([10, 4096])