In [1]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

# 1. Model load and data preprocessing
The difference between bert BertModel and **BertForMaskedLM**:
- The last layer: pooler -> cls (768 -> 30522)

In [4]:
model_type = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(model_type)
bert = BertModel.from_pretrained(model_type)
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
text = ("After Abraham Lincoln won the November 1860 presidential "
        "election on an anti-slavery platform, an initial seven "
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secessionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's "
        "inauguration.")

inputs = tokenizer(text, return_tensors='pt')
list(inputs.keys())

['input_ids', 'token_type_ids', 'attention_mask']

# 2. Masking

In [17]:
inputs['labels'] = inputs['input_ids'].detach().clone() 
# Detach this tensor from the computation graph (so no gradients will be computed), 
# and create a copy to avoid in-place modification.

inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [13]:
mask = torch.rand(inputs['input_ids'].shape) < 0.15

mask

tensor([[False, False, False,  True, False, False, False, False, False, False,
         False,  True, False,  True, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False,  True,  True, False, False, False,  True, False, False,
         False, False, False,  True,  True, False, False, False, False, False,
         False,  True, False,  True, False, False,  True, False, False, False,
         False, False]])

In [23]:
# Keep first, last and 85% token to 'False'
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
        * (inputs['input_ids'] != 101) \
        * (inputs['input_ids'] != 102)

mask_arr

tensor([[False,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False,  True, False,  True,  True, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False,  True, False, False, False,  True,
          True, False]])

In [24]:
# Find all the index whose value is 'True'
selection = torch.flatten(mask_arr[0].nonzero()).tolist()

selection

[1, 2, 15, 17, 18, 37, 48, 55, 59, 60]

In [25]:
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [26]:
tokenizer.vocab['[MASK]']

103

In [28]:
inputs['input_ids'][0, selection] = 103

# There are lots of locations in input_ids be valued to '103'
# labels are the real ids
inputs

{'input_ids': tensor([[  101,   103,   103,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,   103,  1010,   103,   103,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,   103,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,   103,  3792,
          1010,  2074,  2058,  1037,  3204,   103,  5367,  1005,  1055,   103,
           103,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([

### Compare input\_id and labels

In [31]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] [MASK] [MASK] lincoln won the november 1860 presidential election on an anti - slavery [MASK] , [MASK] [MASK] seven slave states declared their secession from the country to form the confederacy . war broke out in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in [MASK] carolina , just over a month [MASK] lincoln ' s [MASK] [MASK] [SEP]"

In [32]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

# 3. Forward and calculate loss

In [33]:
mlm.eval()

with torch.no_grad():
    output = mlm(**inputs)

In [34]:
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

### Check all the three parameters
- loss
- logits
- hidden_states

In [35]:
output.logits

tensor([[[ -7.2286,  -7.1346,  -7.1977,  ...,  -6.3683,  -6.3470,  -4.4159],
         [ -6.6458,  -7.0763,  -6.5377,  ...,  -5.2213,  -6.5514,  -9.1078],
         [ -7.0408,  -7.1989,  -6.7861,  ...,  -6.6813,  -5.6614,  -8.3547],
         ...,
         [ -4.1758,  -4.2652,  -4.1038,  ...,  -2.3127,  -4.2239,  -6.1054],
         [ -8.7734,  -8.5920,  -8.8917,  ...,  -8.0200,  -8.9393,  -3.6660],
         [-12.9003, -13.3303, -13.1156,  ..., -11.2309, -10.8652,  -7.8950]]])

In [36]:
output.loss

tensor(1.1040)

In [42]:
# num of hidden_states(13) = num of encoders(12) + embedding output(1)
len(output.hidden_states)

13

# 4. From scratch

In [47]:
# Get the last layer of hidden_states, that is the input of the mlm layer
mlm.cls(output['hidden_states'][-1])

tensor([[[ -7.2286,  -7.1346,  -7.1977,  ...,  -6.3683,  -6.3470,  -4.4159],
         [ -6.6458,  -7.0763,  -6.5377,  ...,  -5.2213,  -6.5514,  -9.1078],
         [ -7.0408,  -7.1989,  -6.7861,  ...,  -6.6813,  -5.6614,  -8.3547],
         ...,
         [ -4.1758,  -4.2652,  -4.1038,  ...,  -2.3127,  -4.2239,  -6.1054],
         [ -8.7734,  -8.5920,  -8.8917,  ...,  -8.0200,  -8.9393,  -3.6660],
         [-12.9003, -13.3303, -13.1156,  ..., -11.2309, -10.8652,  -7.8950]]],
       grad_fn=<ViewBackward0>)

In [48]:
output.logits

tensor([[[ -7.2286,  -7.1346,  -7.1977,  ...,  -6.3683,  -6.3470,  -4.4159],
         [ -6.6458,  -7.0763,  -6.5377,  ...,  -5.2213,  -6.5514,  -9.1078],
         [ -7.0408,  -7.1989,  -6.7861,  ...,  -6.6813,  -5.6614,  -8.3547],
         ...,
         [ -4.1758,  -4.2652,  -4.1038,  ...,  -2.3127,  -4.2239,  -6.1054],
         [ -8.7734,  -8.5920,  -8.8917,  ...,  -8.0200,  -8.9393,  -3.6660],
         [-12.9003, -13.3303, -13.1156,  ..., -11.2309, -10.8652,  -7.8950]]])

In [49]:
last_hidden_state = output['hidden_states'][-1]

In [50]:
mlm.eval()

with torch.no_grad():
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    logits = mlm.cls.predictions.decoder(transformed)

logits

tensor([[[ -7.2286,  -7.1346,  -7.1977,  ...,  -6.3683,  -6.3470,  -4.4159],
         [ -6.6458,  -7.0763,  -6.5377,  ...,  -5.2213,  -6.5514,  -9.1078],
         [ -7.0408,  -7.1989,  -6.7861,  ...,  -6.6813,  -5.6614,  -8.3547],
         ...,
         [ -4.1758,  -4.2652,  -4.1038,  ...,  -2.3127,  -4.2239,  -6.1054],
         [ -8.7734,  -8.5920,  -8.8917,  ...,  -8.0200,  -8.9393,  -3.6660],
         [-12.9003, -13.3303, -13.1156,  ..., -11.2309, -10.8652,  -7.8950]]])

In [54]:
output.loss

tensor(1.1040)

# 5. Loss and translate

In [51]:
ce = nn.CrossEntropyLoss()

In [52]:
logits.shape

torch.Size([1, 62, 30522])

In [53]:
ce(logits[0], inputs['labels'][0].view(-1))

tensor(1.1040)

### Compare 
- initial -> input_ids
- before model -> logits[0]
- after model -> labels

In [55]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] [MASK] [MASK] lincoln won the november 1860 presidential election on an anti - slavery [MASK] , [MASK] [MASK] seven slave states declared their secession from the country to form the confederacy . war broke out in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in [MASK] carolina , just over a month [MASK] lincoln ' s [MASK] [MASK] [SEP]"

In [56]:
' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))

'. when abraham lincoln won the november 1860 presidential election on an anti - slavery platform , and the seven slave states declared their secession from the country to form the confederacy . war broke erupted in august 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln s s inauguration . )'

In [57]:
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"