In [1]:
import json
import torch
import torch.nn as nn
import dataclasses
from llama import ModelArgs
from llama import Transformer, TransformerBlock

In [2]:
# Hyper-parameters of the LLaMA model instantiated below, forwarded to ModelArgs as keywords.
params = dict(
    dim=4096,           # model (hidden / embedding) dimension
    n_layers=32,        # number of transformer blocks
    n_heads=32,         # number of self-attention heads
    vocab_size=32000,   # vocabulary size
    multiple_of=256,    # the feed-forward hidden dimension is rounded up to a multiple of this
    norm_eps=1e-06,     # epsilon for the normalization layers (presumably RMSNorm's eps)
)
model_args = ModelArgs(**params)

In [3]:
# Instantiate the full model so its parameters can be counted directly.
# NOTE(review): this materializes every weight on the default device. With the
# torch default dtype (float32, assuming it is not changed elsewhere), ~6.7B
# parameters need roughly 27 GB of memory. If only the counts are needed,
# consider constructing the model on the "meta" device instead — TODO confirm
# the llama Transformer constructor supports that.
print("init model ...")
llama_model = Transformer(model_args)

init model ...


In [4]:
# Print the module hierarchy to see which layers own parameters.
print(llama_model)

Transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=4096, bias=False)
        (wv): Linear(in_features=4096, out_features=4096, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=11008, bias=False)
        (w2): Linear(in_features=11008, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=11008, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=32000, bias=False)
)


In [5]:
# Count model parameters.
def stats_module_params(module: nn.Module):
    """Return the total number of scalar elements in ``module``'s parameters.

    Walks ``module.parameters()`` (recursively, trainable or not) and adds up
    each tensor's ``numel()``, so it works for a whole model or any sub-module.
    A module without parameters yields 0.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    return total

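LLaMA模型参数量计算公式（Embedding层 + $L$ 个Transformer block + 最终RMSNorm与输出层）：
$$\text{params} = V \cdot d + L \cdot \left(4d^2 + 3\,d\,d_{ff} + 2d\right) + \left(d + d \cdot V\right)$$
其中 $V$ 为 `vocab_size`，$d$ 为 `dim`，$L$ 为 `n_layers`，$m$ 为 `multiple_of`，前向传播隐藏层维度 $d_{ff} = m \cdot \left\lceil \lfloor 8d/3 \rfloor / m \right\rceil$（7B配置下为 11008）。该公式假设标准多头注意力，即 wq/wk/wv/wo 均为 $d \times d$。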

In [9]:
def llama_params(args: "ModelArgs"):
    """Return the closed-form parameter count of a LLaMA ``Transformer``.

    Breakdown (one term per component):
      1. token embedding:                   vocab_size * dim
      2. each of the n_layers blocks:
           attention wq, wo:                dim * (n_heads * head_dim)     each
           attention wk, wv:                dim * (n_kv_heads * head_dim)  each
           feed-forward w1, w2, w3:         dim * ff_hidden_dim            each
           attention_norm + ffn_norm:       dim + dim
      3. final norm + output projection:    dim + dim * vocab_size

    ``n_kv_heads`` (grouped-query attention) and ``ffn_dim_multiplier`` are read
    with ``getattr``: configs that do not define them, or set them to None, fall
    back to plain multi-head attention and the unscaled feed-forward width. For
    such configs the result equals ``4*dim*dim`` attention weights per block,
    provided ``dim`` is divisible by ``n_heads``.

    Args:
        args: a ModelArgs-like object exposing ``dim``, ``n_layers``,
            ``n_heads``, ``vocab_size`` and ``multiple_of``.

    Returns:
        int: total number of parameters.
    """
    dim = args.dim
    # NOTE(review): assumes the attention layer uses head_dim = dim // n_heads,
    # as in the reference LLaMA code — confirm against llama.Attention.
    head_dim = dim // args.n_heads
    n_kv_heads = getattr(args, "n_kv_heads", None)
    if n_kv_heads is None:
        n_kv_heads = args.n_heads

    # Feed-forward hidden width: int(2 * (4 * dim) / 3), optionally scaled by
    # ffn_dim_multiplier, then rounded up to a multiple of `multiple_of`.
    ff_hidden_dim = int(2 * dim * 4 / 3)
    ffn_dim_multiplier = getattr(args, "ffn_dim_multiplier", None)
    if ffn_dim_multiplier is not None:
        ff_hidden_dim = int(ffn_dim_multiplier * ff_hidden_dim)
    ff_hidden_dim = args.multiple_of * ((ff_hidden_dim + args.multiple_of - 1) // args.multiple_of)

    # One transformer block.
    q_dim = args.n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    attention_params = 2 * dim * q_dim + 2 * dim * kv_dim  # wq + wo, wk + wv
    feed_forward_params = 3 * dim * ff_hidden_dim          # w1, w2, w3
    norm_params = 2 * dim                                  # attention_norm, ffn_norm
    block_params = attention_params + feed_forward_params + norm_params

    embedding_params = args.vocab_size * dim
    head_params = dim + dim * args.vocab_size              # final norm + output projection

    return embedding_params + args.n_layers * block_params + head_params

In [7]:
# Cross-check: the closed-form count must match the instantiated model exactly.
actual_params = stats_module_params(llama_model)
expected_params = llama_params(model_args)
assert actual_params == expected_params, (
    f"closed-form count {expected_params:,d} != instantiated model count {actual_params:,d}"
)
print(f"llama 7B model params: {expected_params:,d}")

llama 7B model params: 6,738,415,616


In [8]:
# Per-component parameter breakdown; for this model the first four rows add up to the total.
for label, component in (
    ("embed", llama_model.tok_embeddings),
    ("layers", llama_model.layers),
    ("norm", llama_model.norm),
    ("output", llama_model.output),
    ("total", llama_model),
):
    print(f"{label} params:\t {stats_module_params(component):,d}")

embed params:	 131,072,000
layers params:	 6,476,267,520
norm params:	 4,096
output params:	 131,072,000
total params:	 6,738,415,616
