In [1]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

### 1. model load and data preprocessing

In [2]:
model_type = 'bert-base-uncased'

In [38]:
tokenizer = BertTokenizer.from_pretrained(model_type)
bert = BertModel.from_pretrained(model_type)
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)

loading file vocab.txt from cache at C:\Users\kennyS\.cache\huggingface\hub\models--bert-base-uncased\snapshots\1dbc166cf8765166998eff31ade2eb64c8a40076\vocab.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at C:\Users\kennyS\.cache\huggingface\hub\models--bert-base-uncased\snapshots\1dbc166cf8765166998eff31ade2eb64c8a40076\tokenizer_config.json
loading file tokenizer.json from cache at C:\Users\kennyS\.cache\huggingface\hub\models--bert-base-uncased\snapshots\1dbc166cf8765166998eff31ade2eb64c8a40076\tokenizer.json
loading configuration file config.json from cache at C:\Users\kennyS\.cache\huggingface\hub\models--bert-base-uncased\snapshots\1dbc166cf8765166998eff31ade2eb64c8a40076\config.json
Model config BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "grad

In [4]:
text = ("After Abraham Lincoln won the November 1860 presidential"
        "election on an anti-slavery platform, an initial seven"
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secesstionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's inauguration.")

In [5]:
text

"After Abraham Lincoln won the November 1860 presidentialelection on an anti-slavery platform, an initial sevenslave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secesstionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

In [6]:
inputs = tokenizer(text, return_tensors='pt')

In [7]:
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883, 12260,
          7542,  2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,
         19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
          2000,  2433,  1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,
          6863,  2043, 10819,  7971,  3508,  2923,  2749,  4457,  3481,  7680,
          3334,  1999,  2148,  3792,  1010,  2074,  2058,  1037,  3204,  2044,
          5367,  1005,  1055, 17331,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1,

In [8]:
inputs['input_ids'].shape

torch.Size([1, 66])

### 2. masking

In [9]:
inputs['labels'] = inputs['input_ids'].detach().clone()

In [10]:
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883, 12260,
          7542,  2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,
         19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
          2000,  2433,  1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,
          6863,  2043, 10819,  7971,  3508,  2923,  2749,  4457,  3481,  7680,
          3334,  1999,  2148,  3792,  1010,  2074,  2058,  1037,  3204,  2044,
          5367,  1005,  1055, 17331,  1012,   102]])

In [11]:
mask = torch.rand(inputs['input_ids'].shape) < 0.15

In [12]:
mask

tensor([[False, False, False, False, False, False, False, False,  True, False,
         False, False, False,  True, False, False, False,  True, False,  True,
         False, False,  True, False, False, False,  True, False,  True, False,
         False, False, False, False,  True, False,  True, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True, False,  True, False, False, False]])

In [13]:
sum(mask[0])

tensor(11)

In [14]:
11/66

0.16666666666666666

In [15]:
mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
    * (inputs['input_ids'] != 101) \
    * (inputs['input_ids'] != 102)

In [16]:
mask_arr

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False,  True, False,
         False, False, False, False, False,  True, False, False, False,  True,
         False, False, False, False, False, False, False, False, False, False,
         False,  True, False, False, False, False,  True, False, False, False,
         False,  True, False, False, False, False]])

In [17]:
sum(mask_arr[0])

tensor(7)

In [18]:
selection = torch.flatten(mask_arr[0].nonzero()).tolist()

In [19]:
selection

[17, 28, 35, 39, 51, 56, 61]

In [20]:
inputs['input_ids'][0, selection] = 103

In [21]:
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883, 12260,
          7542,  2006,  2019,  3424,  1011,  8864,  4132,   103,  2019,  3988,
         19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,   103,  2406,
          2000,  2433,  1996, 18179,  1012,   103,  3631,  2041,  1999,   103,
          6863,  2043, 10819,  7971,  3508,  2923,  2749,  4457,  3481,  7680,
          3334,   103,  2148,  3792,  1010,  2074,   103,  1037,  3204,  2044,
          5367,   103,  1055, 17331,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1,

### 3. forward and calculate loss

In [22]:
mlm.eval()
with torch.no_grad():
    output = mlm(**inputs)

In [23]:
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [24]:
output.logits

tensor([[[ -7.0729,  -6.9823,  -7.0373,  ...,  -6.2476,  -6.2185,  -4.2436],
         [-12.4250, -12.2744, -12.3999,  ..., -11.5897, -11.1210,  -9.2143],
         [ -6.0864,  -6.2552,  -5.7428,  ...,  -6.1365,  -5.8874,  -4.5088],
         ...,
         [ -3.9977,  -4.0382,  -3.9518,  ...,  -2.9818,  -2.6752,  -8.1843],
         [-14.2224, -14.0590, -14.0979,  ..., -11.1725, -11.1214,  -9.1664],
         [-10.4361, -10.7536, -10.2923,  ...,  -9.9935,  -7.9635,  -7.9996]]])

In [25]:
output.loss

tensor(0.7957)

In [26]:
len(output['hidden_states'])

13

In [27]:
ce = nn.CrossEntropyLoss()

In [29]:
inputs['labels'][0].view(-1).shape

torch.Size([66])

In [30]:
last_hidden_state = output['hidden_states'][-1]

In [31]:
mlm.eval()
with torch.no_grad():
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    logits = mlm.cls.predictions.decoder(transformed)
logits

tensor([[[ -7.0729,  -6.9823,  -7.0373,  ...,  -6.2476,  -6.2185,  -4.2436],
         [-12.4250, -12.2744, -12.3999,  ..., -11.5897, -11.1210,  -9.2143],
         [ -6.0864,  -6.2552,  -5.7428,  ...,  -6.1365,  -5.8874,  -4.5088],
         ...,
         [ -3.9977,  -4.0382,  -3.9518,  ...,  -2.9818,  -2.6752,  -8.1843],
         [-14.2224, -14.0590, -14.0979,  ..., -11.1725, -11.1214,  -9.1664],
         [-10.4361, -10.7536, -10.2923,  ...,  -9.9935,  -7.9635,  -7.9996]]])

In [32]:
ce(logits[0], inputs['labels'][0].view(-1))

tensor(0.7957)

In [33]:
torch.argmax(logits[0], dim=1)

tensor([ 1012,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2624,
         7542,  2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,
        19463, 14973,  2063,  2163,  4161,  2037, 22965,  2013,  1996,  2406,
         2000,  2433,  1996, 18179,  1012,  4808, 12591,  2162,  1999,  2285,
         6863,  2043, 10819,  7971,  3508,  2923,  2749,  4457,  3481,  7680,
        10907,  1999,  2148,  3792,  1010,  2074,  2058,  1037,  3204,  2044,
         5367,  1005,  1055, 17331,  1012,  1055])

In [39]:
' '.join(tokenizer.convert_ids_to_token(inputs['input_ids'][0]))

AttributeError: 'BertTokenizer' object has no attribute 'convert_ids_to_token'