In [1]:
import os#环境代理设置
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

In [2]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

### 1. model load and data preprocessing

In [6]:
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert = BertModel.from_pretrained(model_name)
mlm = BertForMaskedLM.from_pretrained(model_name, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


- 输出头
- BertModel
    - pooler=>BertPooler
        - dense: Linear(768,768)
        - activation: Tanh()
- BertForMaskedLM
    - cls=>BertOnlyMLMHead
        - predictions=>BertLMPredictionHead
            - transform=>BertPredictionHeadTransform
                - dense: Linear(768,768)
                - LayerNorm
            - decoder: Linear(768,30522) # (hidden_size, vocab_size) 多分类预测


In [7]:
text = ("After Abraham Lincoln won the November 1860 presidential "
        "election on an anti-slavery platform, an initial seven "
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secessionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's "
        "inauguration.")

In [13]:
inputs = tokenizer(text, return_tensors='pt')

In [14]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [15]:
inputs['input_ids'].shape

torch.Size([1, 62])

In [19]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

### 2. masking

In [25]:
inputs['labels'] = inputs['input_ids'].detach().clone()

In [30]:
inputs['input_ids'].shape

torch.Size([1, 62])

In [70]:
mask = torch.rand(inputs['input_ids'].shape) < 0.15
mask

tensor([[ True, False, False, False, False,  True, False,  True, False,  True,
         False, False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False, False, False, False, False,
          True, False, False,  True, False, False, False, False, False, False,
         False, False,  True, False,  True, False, False, False, False,  True,
          True, False]])

In [None]:
sum(mask[0]) # 12个token被mask

tensor(12)

In [72]:
# 防止[CLS]和[SEP]被mask
mask_arr = mask \
        * (inputs['input_ids'] != 101) \
        * (inputs['input_ids'] != 102)

In [73]:
mask_arr

tensor([[False, False, False, False, False,  True, False,  True, False,  True,
         False, False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False,  True,
         False, False, False, False, False, False, False, False, False, False,
          True, False, False,  True, False, False, False, False, False, False,
         False, False,  True, False,  True, False, False, False, False,  True,
          True, False]])

In [74]:
sum(mask_arr[0])

tensor(11)

In [78]:
# 被mask的token位置
selection = torch.flatten(mask_arr[0].nonzero()).tolist()

In [79]:
selection

[5, 7, 9, 13, 29, 40, 43, 52, 54, 59, 60]

In [80]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [86]:
tokenizer.vocab['[MASK]']

103

In [90]:
inputs['input_ids'][0, selection] = 103

In [92]:
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,   103,  2281,   103,  4883,   103,
          2006,  2019,  3424,   103,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,   103,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
           103,  2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,   103,  1037,   103,  2044,  5367,  1005,  1055,   103,
           103,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([

In [93]:
' ' .join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won [MASK] november [MASK] presidential [MASK] on an anti [MASK] slavery platform , an initial seven slave states declared their secession from the country to [MASK] the confederacy . war broke out in april 1861 when [MASK] ##ist forces [MASK] fort sum ##ter in south carolina , just [MASK] a [MASK] after lincoln ' s [MASK] [MASK] [SEP]"

### 3. forward and calculate loss

In [94]:
mlm.eval()
with torch.no_grad():
    output = mlm(**inputs)

In [95]:
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [97]:
output.logits.shape

torch.Size([1, 62, 30522])

In [98]:
output.loss

tensor(0.7554)

In [None]:
len(output.hidden_states) # embedding*1 + BertLayer*12

13

In [None]:
output.hidden_states[-1].shape # 最后一层layer的输出

torch.Size([1, 62, 768])

### 4. from scratch

In [112]:
mlm.cls(output.hidden_states[-1])

tensor([[[ -7.3198,  -7.2361,  -7.2918,  ...,  -6.4323,  -6.4228,  -4.5251],
         [-11.4993, -11.4074, -11.5982,  ..., -11.0772, -10.3047,  -8.1416],
         [ -7.9648,  -8.2662,  -7.4894,  ...,  -6.8304,  -6.4827,  -7.5817],
         ...,
         [ -3.7624,  -3.8744,  -3.7199,  ...,  -1.8300,  -3.8536,  -5.8013],
         [ -8.7708,  -8.5767,  -8.8928,  ...,  -7.9420,  -9.0211,  -3.4241],
         [-12.8421, -13.2052, -13.0278,  ..., -11.0901, -10.7410,  -7.4169]]],
       grad_fn=<ViewBackward0>)

In [104]:
output.logits

tensor([[[ -7.3198,  -7.2361,  -7.2918,  ...,  -6.4323,  -6.4228,  -4.5251],
         [-11.4993, -11.4074, -11.5982,  ..., -11.0772, -10.3047,  -8.1416],
         [ -7.9648,  -8.2662,  -7.4894,  ...,  -6.8304,  -6.4827,  -7.5817],
         ...,
         [ -3.7624,  -3.8744,  -3.7199,  ...,  -1.8300,  -3.8536,  -5.8013],
         [ -8.7708,  -8.5767,  -8.8928,  ...,  -7.9420,  -9.0211,  -3.4241],
         [-12.8421, -13.2052, -13.0278,  ..., -11.0901, -10.7410,  -7.4169]]])

In [105]:
last_hidden_state = output['hidden_states'][-1]

In [106]:
last_hidden_state.shape

torch.Size([1, 62, 768])

In [114]:
mlm.eval()
with torch.no_grad():
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    print(transformed.shape)
    logits = mlm.cls.predictions.decoder(transformed)
    print(logits.shape)
logits

torch.Size([1, 62, 768])
torch.Size([1, 62, 30522])


tensor([[[ -7.3198,  -7.2361,  -7.2918,  ...,  -6.4323,  -6.4228,  -4.5251],
         [-11.4993, -11.4074, -11.5982,  ..., -11.0772, -10.3047,  -8.1416],
         [ -7.9648,  -8.2662,  -7.4894,  ...,  -6.8304,  -6.4827,  -7.5817],
         ...,
         [ -3.7624,  -3.8744,  -3.7199,  ...,  -1.8300,  -3.8536,  -5.8013],
         [ -8.7708,  -8.5767,  -8.8928,  ...,  -7.9420,  -9.0211,  -3.4241],
         [-12.8421, -13.2052, -13.0278,  ..., -11.0901, -10.7410,  -7.4169]]])

### 5. loss and translate

In [115]:
loss = nn.CrossEntropyLoss()

In [116]:
inputs['labels'].shape

torch.Size([1, 62])

In [119]:
inputs['labels'][0].view(-1).shape

torch.Size([62])

In [120]:
loss(logits[0], inputs['labels'][0])

tensor(0.7554)

In [121]:
torch.argmax(logits, dim=-1)

tensor([[ 1012,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  3693,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  8055,  2749,  4110,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  2420,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,  1012]])

In [123]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won [MASK] november [MASK] presidential [MASK] on an anti [MASK] slavery platform , an initial seven slave states declared their secession from the country to [MASK] the confederacy . war broke out in april 1861 when [MASK] ##ist forces [MASK] fort sum ##ter in south carolina , just [MASK] a [MASK] after lincoln ' s [MASK] [MASK] [SEP]"

In [125]:
' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=-1)))

". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to join the confederacy . war broke out in april 1861 when secession confederate forces captured fort sum ##ter in south carolina , just over days month after lincoln ' s inauguration . ."

In [None]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"