
How is the number of BERT model parameters calculated? #656

Open

aslicedbread opened this issue May 20, 2019 · 12 comments

@aslicedbread

I'm a bit confused about the 110M parameters. How is it calculated?

@superdu12138

In general, a Transformer model has many parameters to train. For example, for the BERT BASE model with L=12, H=768, A=12, the total number of trainable parameters is 12 * 768 * 12 = 110M.

https://zhuanlan.zhihu.com/p/51413773

@careyee

careyee commented Sep 3, 2019

12 * 768 * 12 = 110M? That product is only 110,592, not 110 million.

@laohur

laohur commented Sep 20, 2019

Here is a one-layer Transformer (encoder-decoder), listing every weight tensor and its element count:

#  parameters: 10152448 10152448      
weight name: encoder.src_word_emb.weight size: [5395, 512] count: 2762240
weight name: encoder.position_enc.weight size: [33, 512] count: 16896
weight name: encoder.layer_stack.0.slf_attn.w_qs.weight size: [512, 512] count: 262144
weight name: encoder.layer_stack.0.slf_attn.w_qs.bias size: [512] count: 512
weight name: encoder.layer_stack.0.slf_attn.w_ks.weight size: [512, 512] count: 262144
weight name: encoder.layer_stack.0.slf_attn.w_ks.bias size: [512] count: 512
weight name: encoder.layer_stack.0.slf_attn.w_vs.weight size: [512, 512] count: 262144
weight name: encoder.layer_stack.0.slf_attn.w_vs.bias size: [512] count: 512
weight name: encoder.layer_stack.0.slf_attn.layer_norm.weight size: [512] count: 512
weight name: encoder.layer_stack.0.slf_attn.layer_norm.bias size: [512] count: 512
weight name: encoder.layer_stack.0.slf_attn.fc.weight size: [512, 512] count: 262144
weight name: encoder.layer_stack.0.slf_attn.fc.bias size: [512] count: 512
weight name: encoder.layer_stack.0.pos_ffn.w_1.weight size: [2048, 512, 1] count: 1048576
weight name: encoder.layer_stack.0.pos_ffn.w_1.bias size: [2048] count: 2048
weight name: encoder.layer_stack.0.pos_ffn.w_2.weight size: [512, 2048, 1] count: 1048576
weight name: encoder.layer_stack.0.pos_ffn.w_2.bias size: [512] count: 512
weight name: encoder.layer_stack.0.pos_ffn.layer_norm.weight size: [512] count: 512
weight name: encoder.layer_stack.0.pos_ffn.layer_norm.bias size: [512] count: 512
weight name: decoder.tgt_word_emb.weight size: [5395, 512] count: 2762240
weight name: decoder.position_enc.weight size: [33, 512] count: 16896
weight name: decoder.layer_stack.0.slf_attn.w_qs.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.slf_attn.w_qs.bias size: [512] count: 512
weight name: decoder.layer_stack.0.slf_attn.w_ks.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.slf_attn.w_ks.bias size: [512] count: 512
weight name: decoder.layer_stack.0.slf_attn.w_vs.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.slf_attn.w_vs.bias size: [512] count: 512
weight name: decoder.layer_stack.0.slf_attn.layer_norm.weight size: [512] count: 512
weight name: decoder.layer_stack.0.slf_attn.layer_norm.bias size: [512] count: 512
weight name: decoder.layer_stack.0.slf_attn.fc.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.slf_attn.fc.bias size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.w_qs.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.enc_attn.w_qs.bias size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.w_ks.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.enc_attn.w_ks.bias size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.w_vs.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.enc_attn.w_vs.bias size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.layer_norm.weight size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.layer_norm.bias size: [512] count: 512
weight name: decoder.layer_stack.0.enc_attn.fc.weight size: [512, 512] count: 262144
weight name: decoder.layer_stack.0.enc_attn.fc.bias size: [512] count: 512
weight name: decoder.layer_stack.0.pos_ffn.w_1.weight size: [2048, 512, 1] count: 1048576
weight name: decoder.layer_stack.0.pos_ffn.w_1.bias size: [2048] count: 2048
weight name: decoder.layer_stack.0.pos_ffn.w_2.weight size: [512, 2048, 1] count: 1048576
weight name: decoder.layer_stack.0.pos_ffn.w_2.bias size: [512] count: 512
weight name: decoder.layer_stack.0.pos_ffn.layer_norm.weight size: [512] count: 512
weight name: decoder.layer_stack.0.pos_ffn.layer_norm.bias size: [512] count: 512
weight name: tgt_word_prj.weight size: [5395, 512] count: 2762240

vocab size = 5395, seq len = 33, embedding = 512; you can replace these with BERT's dimensions.
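
For comparison, the same kind of per-tensor listing can be produced for BERT itself. A minimal sketch, assuming the Hugging Face transformers package and the bert-base-uncased checkpoint (not part of this repo):

```python
from transformers import BertModel  # assumption: Hugging Face transformers is installed

# Print every trainable tensor with its shape and element count,
# the same kind of listing as above, but for bert-base-uncased.
model = BertModel.from_pretrained("bert-base-uncased")

total = 0
for name, param in model.named_parameters():
    total += param.numel()
    print(f"weight name: {name} size: {list(param.shape)} count: {param.numel()}")

print("# parameters:", total)  # roughly 110M for BERT-Base
```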

@liuqiangict

1. bert-base-uncased, 110M parameters

| Bert-base-uncased | Key | Shape | Count | Count All |
| --- | --- | --- | --- | --- |
| Embedding | embeddings.word_embeddings.weight | [30522, 768] | 23,440,896 | 23,837,184 |
| | embeddings.position_embeddings.weight | [512, 768] | 393,216 | |
| | embeddings.token_type_embeddings.weight | [2, 768] | 1,536 | |
| | embeddings.LayerNorm.weight | [768] | 768 | |
| | embeddings.LayerNorm.bias | [768] | 768 | |
| Transformer * 12 | encoder.layer.0.attention.self.query.weight | [768, 768] | 589,824 | 7,087,872 * 12 = 85,054,464 |
| | encoder.layer.0.attention.self.query.bias | [768] | 768 | |
| | encoder.layer.0.attention.self.key.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.self.key.bias | [768] | 768 | |
| | encoder.layer.0.attention.self.value.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.self.value.bias | [768] | 768 | |
| | encoder.layer.0.attention.output.dense.weight | [768, 768] | 589,824 | |
| | encoder.layer.0.attention.output.dense.bias | [768] | 768 | |
| | encoder.layer.0.attention.output.LayerNorm.weight | [768] | 768 | |
| | encoder.layer.0.attention.output.LayerNorm.bias | [768] | 768 | |
| | encoder.layer.0.intermediate.dense.weight | [3072, 768] | 2,359,296 | |
| | encoder.layer.0.intermediate.dense.bias | [3072] | 3,072 | |
| | encoder.layer.0.output.dense.weight | [768, 3072] | 2,359,296 | |
| | encoder.layer.0.output.dense.bias | [768] | 768 | |
| | encoder.layer.0.output.LayerNorm.weight | [768] | 768 | |
| | encoder.layer.0.output.LayerNorm.bias | [768] | 768 | |
| Pooler | pooler.dense.weight | [768, 768] | 589,824 | 590,592 |
| | pooler.dense.bias | [768] | 768 | |
| Total | | | | 109,482,240 |
2. bert-large-uncased, 340M parameters

| Bert-large-uncased | Key | Shape | Count | Count All |
| --- | --- | --- | --- | --- |
| Embedding | embeddings.word_embeddings.weight | [30522, 1024] | 31,254,528 | 31,782,912 |
| | embeddings.position_embeddings.weight | [512, 1024] | 524,288 | |
| | embeddings.token_type_embeddings.weight | [2, 1024] | 2,048 | |
| | embeddings.LayerNorm.weight | [1024] | 1,024 | |
| | embeddings.LayerNorm.bias | [1024] | 1,024 | |
| Transformer * 24 | encoder.layer.0.attention.self.query.weight | [1024, 1024] | 1,048,576 | 12,596,224 * 24 = 302,309,376 |
| | encoder.layer.0.attention.self.query.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.self.key.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.self.key.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.self.value.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.self.value.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.dense.weight | [1024, 1024] | 1,048,576 | |
| | encoder.layer.0.attention.output.dense.bias | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.LayerNorm.weight | [1024] | 1,024 | |
| | encoder.layer.0.attention.output.LayerNorm.bias | [1024] | 1,024 | |
| | encoder.layer.0.intermediate.dense.weight | [4096, 1024] | 4,194,304 | |
| | encoder.layer.0.intermediate.dense.bias | [4096] | 4,096 | |
| | encoder.layer.0.output.dense.weight | [1024, 4096] | 4,194,304 | |
| | encoder.layer.0.output.dense.bias | [1024] | 1,024 | |
| | encoder.layer.0.output.LayerNorm.weight | [1024] | 1,024 | |
| | encoder.layer.0.output.LayerNorm.bias | [1024] | 1,024 | |
| Pooler | pooler.dense.weight | [1024, 1024] | 1,048,576 | 1,049,600 |
| | pooler.dense.bias | [1024] | 1,024 | |
| Total | | | | 335,141,888 |
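
For anyone who wants to check these totals without loading a checkpoint, here is a closed-form sketch of the same arithmetic. The variable names are mine; the configuration values are the ones used in the tables above:

```python
# Reproduce the 109,482,240 total for bert-base-uncased from its configuration:
# V = vocab size, P = max position embeddings, H = hidden size,
# I = intermediate (feed-forward) size, L = number of layers.
V, P, H, I, L = 30522, 512, 768, 3072, 12

embeddings = (V + P + 2) * H + 2 * H            # word/position/token-type embeddings + LayerNorm
attention  = 4 * (H * H + H) + 2 * H            # Q, K, V, output projections (+biases) + LayerNorm
ffn        = (H * I + I) + (I * H + H) + 2 * H  # intermediate dense + output dense + LayerNorm
pooler     = H * H + H

total = embeddings + L * (attention + ffn) + pooler
print(total)  # 109482240
```

Swapping in H=1024, I=4096, L=24 gives 335,141,888 for bert-large-uncased, matching the second table.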

@hezq06

hezq06 commented Feb 5, 2021

So does the attention head number get included?

@oyvistee

I think the number of attention heads is chosen such that H / A = 64 for all models, where H is the hidden size and A is the number of attention heads.
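
A quick sanity check of that ratio, using the configuration values published for the two models (this snippet is only illustrative):

```python
# Per-head dimension for the published BERT configurations.
configs = {
    "BERT-Base":  {"H": 768,  "A": 12},
    "BERT-Large": {"H": 1024, "A": 16},
}

for name, cfg in configs.items():
    head_dim = cfg["H"] // cfg["A"]
    print(f"{name}: H={cfg['H']}, A={cfg['A']}, H/A={head_dim}")  # both print H/A=64
```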

@amineabdaoui

Thanks @liuqiangict
So the query, key and value weights are shared across all the attention heads of the same layer?

@jong-won-lee

> Thanks @liuqiangict So the query, key and value weights are shared across all the attention heads of the same layer?

They are different. If they were shared, the weight size could be reduced by a factor of the number of heads.

@DongqiShen

> So does the attention head number get included?

Yes, it does. For each head, the attention layer projects the input (size [768]) down to a smaller size ([64]). There are 12 heads in the attention layer, and 64 * 12 = 768. The implementation does not keep 12 separate head projections explicitly; instead, the 12 heads are packed together into one linear layer (768 * 768). In terms of the code, the two views are the same.
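
A minimal sketch of that equivalence in PyTorch. Shapes follow BERT-Base; the tensor and variable names here are illustrative, not taken from any particular implementation:

```python
import torch

H, A = 768, 12            # hidden size and number of heads for BERT-Base
head_dim = H // A         # 64

x = torch.randn(1, 5, H)                 # (batch, seq_len, hidden)
w_q = torch.nn.Linear(H, H, bias=True)   # one fused query projection for all 12 heads

q = w_q(x)                                # (1, 5, 768)
q_heads = q.view(1, 5, A, head_dim)       # (1, 5, 12, 64): same numbers, split per head
q_heads = q_heads.permute(0, 2, 1, 3)     # (1, 12, 5, 64) for per-head attention scores

# The fused Linear holds 768*768 weights + 768 biases, exactly what the tables above count.
print(sum(p.numel() for p in w_q.parameters()))  # 590592
```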
