In [8]:
from transformers import BertForMaskedLM, BertTokenizer, BertModel
import torch
from torch import nn

## 1. model load and data processing

In [9]:
# Load the tokenizer plus two views of the same checkpoint:
#   - `bert`: the bare encoder (the checkpoint's `cls.*` head weights are not used),
#   - `mlm`:  the encoder with the masked-LM head; `output_hidden_states=True`
#             makes its forward pass also return the per-layer hidden states.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert = BertModel.from_pretrained(model_name)
mlm = BertForMaskedLM.from_pretrained(model_name, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if

In [10]:
# Show the module tree of the bare encoder (embeddings + 12 encoder layers).
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [11]:
# Show the module tree of the masked-LM model; the encoder is nested under `bert`.
mlm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [49]:
# Sample passage used as the single input example for the rest of the
# notebook. It is written as two sentences and joined with one space.
text = " ".join([
    "After Abraham Lincoln won the November 1860 presidential election on an "
    "anti-slavery platform, an initial seven slave states declared their "
    "secession from the country to form the Confederacy.",
    "War broke out in April 1861 when secessionist forces attacked Fort "
    "Sumter in South Carolina, just over a month after Lincoln's "
    "inauguration.",
])
text

"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

In [14]:
# Tokenize the passage into PyTorch tensors (input_ids, token_type_ids,
# attention_mask). The tokenizer wraps the sequence in [CLS] (101) ... [SEP] (102).
inputs = tokenizer(text,return_tensors='pt')
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [15]:
# One sequence of 62 token ids (including [CLS] and [SEP]).
inputs['input_ids'].shape

torch.Size([1, 62])

In [17]:
# Map the ids back to WordPiece strings to see how the text was split
# (e.g. "secessionist" -> 'secession', '##ist').
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

['[CLS]',
 'after',
 'abraham',
 'lincoln',
 'won',
 'the',
 'november',
 '1860',
 'presidential',
 'election',
 'on',
 'an',
 'anti',
 '-',
 'slavery',
 'platform',
 ',',
 'an',
 'initial',
 'seven',
 'slave',
 'states',
 'declared',
 'their',
 'secession',
 'from',
 'the',
 'country',
 'to',
 'form',
 'the',
 'confederacy',
 '.',
 'war',
 'broke',
 'out',
 'in',
 'april',
 '1861',
 'when',
 'secession',
 '##ist',
 'forces',
 'attacked',
 'fort',
 'sum',
 '##ter',
 'in',
 'south',
 'carolina',
 ',',
 'just',
 'over',
 'a',
 'month',
 'after',
 'lincoln',
 "'",
 's',
 'inauguration',
 '.',
 '[SEP]']

## 2. masking

In [19]:
# Snapshot the token ids as training targets BEFORE any [MASK] substitution,
# so `labels` keeps the original id at every position. `.detach().clone()`
# gives an independent copy that later in-place edits of `input_ids` won't touch.
inputs['labels'] = inputs['input_ids'].detach().clone()
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [26]:
# Naive candidate mask: each position is selected independently with
# probability 0.15. Nothing here excludes [CLS]/[SEP], so they can be picked.
# NOTE(review): no RNG seed is set in the visible cells, so the draw differs per run.
mask = torch.rand(inputs['input_ids'].shape) < 0.15
mask

tensor([[False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False, False, False, False, False, False,
          True, False, False, False, False, False, False,  True, False, False,
         False,  True, False, False,  True, False, False, False,  True, False,
         False, False,  True,  True,  True,  True, False,  True, False, False,
          True, False, False, False,  True, False, False, False, False,  True,
         False, False]])

In [27]:
# Count how many positions the naive 15% draw selected.
mask[0].sum()

tensor(14)

In [30]:
# Draw a random 15% candidate mask, then drop the [CLS] and [SEP] positions so
# special tokens are never replaced by [MASK]. The special-token ids are read
# from the tokenizer instead of hard-coding 101/102, so the cell keeps working
# if the checkpoint (and therefore the vocabulary) is changed.
# NOTE(review): no RNG seed is set in the visible cells, so the selected
# positions differ from run to run.
is_special = (
    (inputs['input_ids'] == tokenizer.cls_token_id)
    | (inputs['input_ids'] == tokenizer.sep_token_id)
)
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) & ~is_special
# Number of positions that will actually be masked.
mask_arr[0].sum()

tensor(7)

In [34]:
# Turn the first (only) row of the boolean mask into a plain Python list of
# bools, one entry per token position, for use as an index in the next step.
selection = mask_arr[0].reshape(-1).ne(0).tolist()
selection

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 True,
 False,
 True,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False]

In [36]:
# Look up the vocabulary ids of the tokenizer's special tokens. From the cells
# above, 101/102 are [CLS]/[SEP]; 103 is the id written in the masking step
# and decodes to [MASK]; 0 matches the embedding's padding_idx.
# NOTE(review): 100 is presumably [UNK] — confirm against tokenizer.special_tokens_map.
tokenizer.convert_tokens_to_ids(tokenizer.special_tokens_map.values())

[100, 102, 0, 101, 103]

In [40]:
# Overwrite the selected positions of the model input, in place, with the
# [MASK] id. The id is taken from the tokenizer rather than the literal 103 so
# it stays correct for other vocabularies.
inputs['input_ids'][0, selection] = tokenizer.mask_token_id
# `labels` was cloned before masking, so it still decodes to the original text.
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

In [41]:
# Decode the (now masked) model input to see which tokens became [MASK].
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven [MASK] states declared [MASK] secession from the country to form the confederacy . [MASK] broke out in [MASK] 1861 [MASK] secession [MASK] forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln ' s inauguration . [SEP]"

## 3. forward and calculate loss

In [42]:
# Forward pass in eval mode and without gradient tracking. Because `inputs`
# also carries `labels`, the model returns a loss alongside the logits.
# NOTE: `labels` holds the original id at EVERY position (nothing is set to an
# ignore value), so this loss is averaged over all 62 tokens, not only the
# masked ones; the manual cross-entropy over all positions further down
# reproduces the same value.
mlm.eval()
with torch.no_grad():
    outputs = mlm(**inputs)

outputs.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [44]:
# Scalar loss, and logits with one score per vocabulary entry (30522) for each
# of the 62 positions.
outputs['loss'],outputs['logits'].shape

(tensor(0.6899), torch.Size([1, 62, 30522]))

In [46]:
# Number of hidden-state tensors returned.
# NOTE(review): 13 is presumably the embedding output plus one per each of the
# 12 encoder layers — confirm against the transformers docs.
len(outputs['hidden_states'])

13

In [47]:
# The last hidden state holds one 768-dimensional vector per token position.
outputs['hidden_states'][-1].shape

torch.Size([1, 62, 768])

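## 4. from scratch
- Recompute the logits by hand: pass the last hidden state through the MLM head (`transform`: dense → GELU → LayerNorm, then `decoder`) and compare with the model's own `logits`.
- LayerNorm: $y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta$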

In [48]:
# Inspect the MLM head: a transform (dense -> GELU -> LayerNorm) followed by a
# decoder that projects 768 features onto the 30522-entry vocabulary.
mlm.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (transform_act_fn): GELUActivation()
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=True)
  )
)

In [50]:
# Rebuild the logits outside the model's own forward pass by running only the
# MLM head (transform, then decoder) on the final hidden state.
mlm.eval()
head = mlm.cls.predictions
last_hidden = outputs['hidden_states'][-1]
with torch.no_grad():
    transformed = head.transform(last_hidden)
    print(transformed.shape)
    logits = head.decoder(transformed)
    print(logits.shape)
logits

torch.Size([1, 62, 768])
torch.Size([1, 62, 30522])


tensor([[[ -7.1853,  -7.1001,  -7.1574,  ...,  -6.3820,  -6.2823,  -4.3645],
         [-12.0900, -11.9610, -12.1337,  ..., -11.2234, -10.7172,  -8.6875],
         [ -6.2009,  -6.3899,  -5.8585,  ...,  -6.1539,  -6.1230,  -5.0137],
         ...,
         [ -1.7387,  -1.6272,  -1.6480,  ...,  -1.1015,  -0.7714,  -7.5605],
         [-14.1621, -14.0979, -14.1320,  ..., -11.0793, -11.3748,  -9.6525],
         [-11.0845, -11.5418, -11.0382,  ..., -10.7088,  -8.4698,  -9.3537]]])

In [51]:
# Logits from the model's own forward pass, for comparison with the manually
# rebuilt `logits` in the previous cell.
outputs['logits']

tensor([[[ -7.1853,  -7.1001,  -7.1574,  ...,  -6.3820,  -6.2823,  -4.3645],
         [-12.0900, -11.9610, -12.1337,  ..., -11.2234, -10.7172,  -8.6875],
         [ -6.2009,  -6.3899,  -5.8585,  ...,  -6.1539,  -6.1230,  -5.0137],
         ...,
         [ -1.7387,  -1.6272,  -1.6480,  ...,  -1.1015,  -0.7714,  -7.5605],
         [-14.1621, -14.0979, -14.1320,  ..., -11.0793, -11.3748,  -9.6525],
         [-11.0845, -11.5418, -11.0382,  ..., -10.7088,  -8.4698,  -9.3537]]])

## 5. loss calculation and translation

In [52]:
# Loss reported by the model's forward pass; reproduced by hand below.
outputs['loss']

tensor(0.6899)

In [53]:
# Cross-entropy criterion with PyTorch defaults.
# NOTE(review): presumably mean reduction over all targets — verify against
# the nn.CrossEntropyLoss docs.
ce = nn.CrossEntropyLoss()

In [56]:
# Targets for the single sequence as a flat vector: one class id per position.
inputs['labels'][0].view(-1).shape

torch.Size([62])

In [57]:
# Manually rebuilt logits: 1 sequence x 62 positions x 30522 vocabulary scores.
logits.shape

torch.Size([1, 62, 30522])

In [58]:
# Reproduce the model's loss by hand. Every (batch, position) pair is flattened
# into one row of vocabulary scores with one target id, so the result matches
# `outputs['loss']` for any batch size instead of scoring only sample 0.
# With the single sequence used here the value is unchanged.
ce(logits.reshape(-1, logits.size(-1)), inputs['labels'].reshape(-1))

tensor(0.6899)

In [59]:
# Highest-scoring vocabulary id at each position of the single sequence.
logits[0].argmax(dim=-1)

tensor([ 1012,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
         2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
         2670,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
         1996, 18179,  1012,  4808,  3631,  2041,  1999,  2285,  6863,  2043,
        22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
         1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
         1012,  1055])

In [60]:
# Decode the top prediction at every position back into tokens.
predicted_ids = logits[0].argmax(dim=-1)
' '.join(tokenizer.convert_ids_to_tokens(predicted_ids))

". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven southern states declared their secession from the country to form the confederacy . violence broke out in december 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s"

In [61]:
# Masked model input, shown again for side-by-side comparison with the predictions.
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven [MASK] states declared [MASK] secession from the country to form the confederacy . [MASK] broke out in [MASK] 1861 [MASK] secession [MASK] forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln ' s inauguration . [SEP]"

In [62]:
# Original (unmasked) tokens, i.e. the targets the predictions are compared against.
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"