# Train Using MLM

In this notebook we'll take a look at fine-tuning a model using masked-language modelling (MLM).

First, we'll import everything we need.

In [1]:
from transformers import BertTokenizer, BertForMaskedLM
import torch

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy. War broke out in "
        "April 1861 when secessionist forces [MASK] Fort Sumter in South "
        "Carolina, just over a month after Lincoln's inauguration.")

inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
outputs.keys()

odict_keys(['logits'])

This returns only the MLM output logits. No `loss` is included because we did not pass any `labels` to the model.

In [5]:
outputs.logits.shape

torch.Size([1, 62, 30522])
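The logits tensor has shape `(batch_size, sequence_length, vocab_size)`: one sentence, 62 tokens, and one score for each of the 30522 entries in the BERT vocabulary. The following is a minimal sketch of how a prediction is read from such a tensor. It uses a small random stand-in rather than the real model output, and the sizes are made up for illustration.

```python
import torch

# Stand-in for outputs.logits: (batch_size, sequence_length, vocab_size).
# Small made-up sizes; the real tensor in this notebook is (1, 62, 30522).
batch_size, seq_len, vocab_size = 1, 8, 20
torch.manual_seed(0)
logits = torch.randn(batch_size, seq_len, vocab_size)

# Highest-scoring vocabulary ID at every position of the first sequence
predicted_ids = logits[0].argmax(dim=-1)

print(logits.shape)         # torch.Size([1, 8, 20])
print(predicted_ids.shape)  # torch.Size([8])
```

With the real model, the predicted IDs at the mask positions could be turned back into words with `tokenizer.decode`.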

To identify the positions of the **\[MASK\]** tokens, we can check the `input_ids` tensor for tokens matching *103* (the ID of the **\[MASK\]** token).

In [17]:
mask_pos = torch.flatten((inputs.input_ids[0] == 103).nonzero()).tolist()
mask_pos

[9, 43]
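Hard-coding 103 works for `bert-base-uncased`, but Hugging Face tokenizers also expose this ID as `tokenizer.mask_token_id`, which keeps the lookup correct if the checkpoint changes. The sketch below wraps the same `nonzero` lookup in a small helper. The helper name `find_mask_positions` and the toy token IDs are hypothetical and chosen only for illustration.

```python
import torch

def find_mask_positions(input_ids: torch.Tensor, mask_token_id: int) -> list[int]:
    """Return the indices in a 1-D tensor of token IDs that equal mask_token_id."""
    # as_tuple=True on a 1-D tensor yields a one-element tuple of indices
    (positions,) = torch.nonzero(input_ids == mask_token_id, as_tuple=True)
    return positions.tolist()

# Toy sequence (hypothetical IDs); 103 plays the role of [MASK]
toy_ids = torch.tensor([101, 2044, 103, 2001, 103, 102])
print(find_mask_positions(toy_ids, mask_token_id=103))  # [2, 4]
```

In this notebook it would be called as `find_mask_positions(inputs.input_ids[0], tokenizer.mask_token_id)`.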

These two positions are where we must calculate the loss when training our model. How does that work? We compare the true tokens at those two positions (the words hidden behind **\[MASK\]**) with the predicted `outputs` at the same positions. The true tokens are converted to a one-hot encoding, and the predictions to a probability distribution.
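This comparison is cross-entropy. With a one-hot target, summing `one_hot * log(softmax(logits))` over the vocabulary leaves only the log-probability of the true token, so the loss at each masked position is `-log p(true token)`. The following is a minimal sketch with random logits and made-up target IDs, not the notebook's real values. It checks the manual calculation against `torch.nn.functional.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10

# Predicted logits at two masked positions: shape (num_masked, vocab_size)
logits = torch.randn(2, vocab_size)
# True token IDs at those positions (made up for illustration)
targets = torch.tensor([3, 7])

# Targets as one-hot vectors, predictions as log-probability distributions
one_hot = F.one_hot(targets, num_classes=vocab_size).float()
log_probs = F.log_softmax(logits, dim=-1)

# Cross-entropy per position, then averaged over the masked positions
manual_loss = -(one_hot * log_probs).sum(dim=-1).mean()
builtin_loss = F.cross_entropy(logits, targets)

print(torch.allclose(manual_loss, builtin_loss))  # True
```

`log_softmax` is used instead of `torch.log(torch.softmax(...))` because it is the numerically safer way to get log-probabilities.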

To convert tokens to one-hot encodings, we need the vocab dictionary, which maps each token to its ID. Its size sets the length of each one-hot vector.
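A minimal sketch of that conversion is shown below, assuming a tiny hypothetical vocabulary in place of `tokenizer.get_vocab()`, which has 30522 entries for `bert-base-uncased`. Each token is looked up to get its ID, and that ID marks the single 1 in a vector whose length is the vocabulary size. The helper `one_hot_tokens` is hypothetical.

```python
import torch
import torch.nn.functional as F

# Hypothetical miniature vocabulary standing in for tokenizer.get_vocab()
token2idx = {"[PAD]": 0, "[MASK]": 1, "election": 2, "attacked": 3, "war": 4}

def one_hot_tokens(tokens, vocab):
    """One-hot encode a list of token strings using a token -> ID dictionary."""
    ids = torch.tensor([vocab[t] for t in tokens])
    # num_classes = vocabulary size, so every row has exactly one 1
    return F.one_hot(ids, num_classes=len(vocab))

encoded = one_hot_tokens(["election", "attacked"], token2idx)
print(encoded)
```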

In [18]:
token2idx = tokenizer.get_vocab()

In [20]:
inputs.input_ids[0][mask_pos]

tensor([103, 103])
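The output shows that `inputs` holds only the **\[MASK\]** ID (103) at both positions. The tokens we want the model to recover are therefore no longer in `inputs` and have to be kept separately. One common pattern is to copy the original `input_ids` into a `labels` tensor before masking, then set every unmasked position to -100, which the loss ignores. This is the convention `BertForMaskedLM` follows when given a `labels` argument. The sketch below simulates it with made-up toy IDs and random logits instead of the real tokenizer and model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, mask_id = 50, 3  # toy values; bert-base-uncased uses 30522 and 103

# Original (unmasked) token IDs for one toy sequence
input_ids = torch.tensor([[1, 10, 11, 12, 13, 14, 2]])

# Keep a copy of the true tokens as labels BEFORE masking
labels = input_ids.clone()

# Choose two positions to mask and replace them in the model input
mask = torch.zeros_like(input_ids, dtype=torch.bool)
mask[0, [2, 5]] = True
input_ids = input_ids.masked_fill(mask, mask_id)

# Only masked positions should contribute to the loss; -100 is ignored
labels = labels.masked_fill(~mask, -100)

# Stand-in for model(input_ids).logits: (batch, seq_len, vocab_size)
logits = torch.randn(1, input_ids.shape[1], vocab_size)

# Loss over the whole sequence, skipping positions labelled -100
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)

# Same value computed directly over the masked positions only
manual = F.cross_entropy(logits[mask], labels[mask])
print(torch.allclose(loss, manual))  # True
```

With the real model, passing `labels=labels` alongside the masked inputs should make `outputs` include a `loss` computed this way.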

https://huggingface.co/transformers/training.html