In [1]:
import json
import torch
import torch.nn as nn
import dataclasses
from llama import ModelArgs
from llama import Transformer, TransformerBlock

In [2]:
# Hyper-parameters for a ~6.7B-parameter LLaMA configuration (see the parameter count below).
params = dict(
    dim=4096,          # hidden size of the model
    n_layers=32,       # number of TransformerBlocks
    n_heads=32,        # number of self-attention heads per block
    vocab_size=32000,  # tokenizer vocabulary size
    multiple_of=256,   # used to compute the feed-forward hidden size (rounded to a multiple of this)
    norm_eps=1e-06,    # epsilon handed to the model's normalisation layers (presumably RMSNorm) — confirm in llama
)
model_args = ModelArgs(**params)

# NOTE(review): this materialises every weight; at the default float32 dtype that is
# roughly 27 GB of host memory for ~6.7B parameters — assumes torch's default dtype is unchanged.
print("init model ...")
llama_model = Transformer(model_args)

init model ...


In [3]:
def module_param_num(module: nn.Module, trainable_only: bool = False) -> int:
    """Return the number of scalar parameters held by ``module`` and all its sub-modules.

    Args:
        module: any ``nn.Module``; parameters of nested sub-modules are included.
        trainable_only: when True, count only parameters with ``requires_grad=True``.
            Defaults to False, i.e. every parameter is counted.

    Returns:
        The summed element count. A tensor shared between sub-modules is counted
        once, because ``nn.Module.parameters()`` yields each parameter only once.
    """
    # Generator expression: no intermediate list of per-tensor counts is built.
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )

In [4]:
# Module tree with layer types and shapes. As the cell's last expression, Jupyter renders its repr
# (the same text print() would emit), so no print() wrapper is needed.
llama_model

Transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=4096, bias=False)
        (wv): Linear(in_features=4096, out_features=4096, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=11008, bias=False)
        (w2): Linear(in_features=11008, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=11008, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=32000, bias=False)
)


In [5]:
# Parameter count per top-level child of the Transformer (see the module tree above).
param_breakdown = {
    "embed": module_param_num(llama_model.tok_embeddings),
    "layers": module_param_num(llama_model.layers),
    "norm": module_param_num(llama_model.norm),
    "output": module_param_num(llama_model.output),
}
total_params = module_param_num(llama_model)

for part_name, n_params in param_breakdown.items():
    print(f"{part_name} params:\t {n_params:,d}")
print(f"total params:\t {total_params:,d}")

# parameters() yields each tensor once, so a mismatch here would mean tied weights between
# children or a parameter stored outside these four sub-modules.
assert sum(param_breakdown.values()) == total_params, "component counts do not add up to the total"

embed params:	 131,072,000
layers params:	 6,476,267,520
norm params:	 4,096
output params:	 131,072,000
total params:	 6,738,415,616


In [8]:
def llama_model_params(args) -> int:
    """Count the parameters of a LLaMA ``Transformer`` in closed form from its ``ModelArgs``.

    The formula follows the module shapes printed for ``llama_model``: a token
    embedding (vocab_size x dim), ``n_layers`` blocks each holding the wq/wk/wv/wo
    attention projections, the w1/w2/w3 feed-forward projections and two norm
    weight vectors of length dim, then a final norm and a separate output head
    (dim x vocab_size). All Linear layers are bias-free.

    Args:
        args: a ``ModelArgs``-like object exposing ``dim``, ``n_layers``, ``n_heads``,
            ``vocab_size`` and ``multiple_of``. Optional ``n_kv_heads`` and
            ``ffn_dim_multiplier`` attributes (LLaMA-2 style) are honoured when
            present and not None; otherwise the LLaMA-1 shapes are used.

    Returns:
        Total number of scalar parameters.
    """
    dim = args.dim
    n_heads = args.n_heads
    head_dim = dim // n_heads

    # Grouped-query attention shrinks wk/wv; without n_kv_heads every head has its own K/V.
    n_kv_heads = getattr(args, "n_kv_heads", None)
    if n_kv_heads is None:
        n_kv_heads = n_heads

    # FeedForward hidden size: 2/3 of 4*dim, optionally scaled, then rounded UP to a
    # multiple of multiple_of (dim=4096 -> 11008, matching w1/w3 in the module tree).
    hidden_dim = int(2 * (4 * dim) / 3)
    ffn_dim_multiplier = getattr(args, "ffn_dim_multiplier", None)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

    attention = (
        dim * (n_heads * head_dim)             # wq
        + 2 * dim * (n_kv_heads * head_dim)    # wk, wv
        + (n_heads * head_dim) * dim           # wo
    )
    feed_forward = 3 * dim * hidden_dim        # w1, w2, w3
    norms = 2 * dim                            # attention_norm, ffn_norm
    per_layer = attention + feed_forward + norms

    embedding = args.vocab_size * dim
    output_head = dim * args.vocab_size
    final_norm = dim
    return embedding + args.n_layers * per_layer + final_norm + output_head


print(f"config: {json.dumps(dataclasses.asdict(model_args), ensure_ascii=False, indent=4)}")
analytic_total = llama_model_params(model_args)
print(f"total params: {analytic_total:,d}")
# Cross-check the closed-form count against the instantiated model.
assert analytic_total == module_param_num(llama_model), "analytic count disagrees with the instantiated model"

config: {
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "vocab_size": 32000,
    "multiple_of": 256,
    "norm_eps": 1e-06,
    "max_batch_size": 32,
    "max_seq_len": 2048
}
total params: 6,738,415,616
