In [12]:
from transformers import BertModel

In [13]:
# Checkpoint identifier handed to `from_pretrained` when the model is loaded.
model_name = "bert-base-uncased"

In [14]:
# Instantiate BertModel with the pretrained weights identified by `model_name`.
model = BertModel.from_pretrained(pretrained_model_name_or_path=model_name)

In [15]:
# Display the loaded model; its repr lists the submodule tree (embeddings, encoder layers, ...).
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

### Summary
- BERT: the encoder part of the Transformer
    - Transformer: an encoder-decoder (seq2seq) architecture

- BERT
    - embeddings
        - word (token) embeddings
        - position embeddings
        - token type (segment) embeddings
    - encoder (12 layers), each layer containing
        - self-attention
        - feed-forward network
    - pooler

In [16]:
# Parameter-count statistics: accumulators filled in by the counting loop in the next cell.
total_params = total_learnable_params = 0
total_embedding_params = total_encoder_params = total_pooler_params = 0

In [17]:
# Count parameters overall, learnable ones, and per component (embeddings / encoder / pooler).
# The accumulators are reset here so that re-running this cell recomputes the totals
# instead of adding onto the values left over from a previous run.
total_params = 0
total_learnable_params = 0
total_embedding_params = 0
total_encoder_params = 0
total_pooler_params = 0

for name, param in model.named_parameters():
    n_elements = param.numel()
    # Component membership is decided by substring match on the qualified parameter
    # name, which follows the module tree (e.g. 'embeddings.word_embeddings.weight',
    # 'encoder.layer.0.attention.self.query.weight').
    if 'embedding' in name:
        total_embedding_params += n_elements
    if 'encoder' in name:
        total_encoder_params += n_elements
    if 'pooler' in name:
        total_pooler_params += n_elements
    if param.requires_grad:
        total_learnable_params += n_elements
    total_params += n_elements

In [18]:
# Total number of parameter elements summed over model.named_parameters().
total_params

109482240

In [19]:
# Elements of parameters with requires_grad=True; the output matches total_params,
# i.e. no parameter of the freshly loaded model is frozen.
total_learnable_params

109482240

In [20]:
# Raw parameter counts per component, in order: embeddings, encoder, pooler.
print(f"{total_embedding_params} {total_encoder_params} {total_pooler_params}")

23837184 85054464 590592


In [21]:
# Share of the grouped parameter count held by each component
# (embeddings, encoder, pooler — one value per line, in that order).
params = [total_embedding_params, total_encoder_params, total_pooler_params]
grouped_total = sum(params)
print(*(count / grouped_total for count in params), sep="\n")

0.21772649152958506
0.776879099295009
0.005394409175405983
