# How BERT is Trained

In this notebook we'll walk through the two objectives used to pretrain the core BERT model: Masked-Language Modelling (MLM) and Next Sentence Prediction (NSP). We'll explore how each one works by running a pretrained model and inspecting its outputs.

First, let's import everything we need.

In [3]:
from transformers import BertTokenizer, BertForPreTraining
import torch

Next, we initialize our tokenizer and model, and tokenize a paragraph of text from the Wikipedia page on the American Civil War in which two words have been replaced with **\[MASK\]**. Finally, we process these tokenized `inputs` through our initialized model to produce our model `outputs`.

In [6]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy. War broke out in "
        "April 1861 when secessionist forces [MASK] Fort Sumter in South "
        "Carolina, just over a month after Lincoln's inauguration.")

inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)

Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We're using the `BertForPreTraining` class, which gives us two outputs:

In [8]:
outputs.keys()

odict_keys(['prediction_logits', 'seq_relationship_logits'])

In [9]:
outputs.prediction_logits

tensor([[[ -7.6192,  -7.5433,  -7.6124,  ...,  -6.7155,  -6.7375,  -4.6122],
         [-12.5489, -12.3772, -12.6500,  ..., -11.8643, -11.4446,  -9.1151],
         [ -6.2346,  -6.3590,  -5.9091,  ...,  -6.1258,  -6.2720,  -5.0268],
         ...,
         [ -2.2497,  -2.1352,  -2.1812,  ...,  -1.7201,  -1.2728,  -7.8302],
         [-14.2654, -14.3100, -14.2294,  ..., -11.4669, -11.7212, -10.3129],
         [-11.5071, -12.0389, -11.6046,  ..., -11.2875,  -9.1655,  -9.1732]]],
       grad_fn=<AddBackward0>)

There are 62 tokens (60 plus \[CLS\] and \[SEP\]), which we can see reflected in `prediction_logits.shape` alongside the batch size of 1 and the vocab size of 30522:

In [10]:
outputs.prediction_logits.shape

torch.Size([1, 62, 30522])

In [11]:
outputs.seq_relationship_logits

tensor([[ 2.8256, -1.6897]], grad_fn=<AddmmBackward>)

* `outputs.prediction_logits` is the output from the MLM head (for every token position, one logit per vocab entry, which maps to a word from the vocab after *softmax* and *argmax*)

* `outputs.seq_relationship_logits` is the output from the NSP head (two logits, where index 0 is *IsNext* and index 1 is *NotNext*, indicating whether the second sentence follows the first)

## Masked Language Modelling

First, let's convert our `prediction_logits` into token predictions. To do this, we'll need to get a mapping between index values and words from the model vocab, which we can extract from the `tokenizer`.

In [12]:
token2idx = tokenizer.get_vocab()

Then we invert the dictionary to create an index to token dictionary.

In [13]:
idx2token = {value:key for key, value in token2idx.items()}

Now all we need to do is take `prediction_logits` where we had a **\[MASK\]** token and process it through a softmax function, followed by argmax, to get our index prediction. We don't know the exact index of our mask token right now, so let's first pick an arbitrary index, number **2**.

In [22]:
outputs.prediction_logits[0][2].shape

torch.Size([30522])

The shape here matches our vocab size:

In [23]:
len(idx2token)

30522

Next, we take the softmax to get a probability distribution across the *30522* tokens, and extract the most probable one using an argmax function:

In [24]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0][2], dim=-1)  # create probability distribution
argmax = torch.argmax(softmax)  # get index of the max probability

In [25]:
argmax

tensor(8181)
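
Because softmax is strictly increasing, it never changes which entry is the largest, so the same index would come out of an argmax over the raw logits. The softmax step only matters when the probabilities themselves are needed. The following is a minimal, dependency-free sketch of that point; the five logits are made up for illustration, and `toy_softmax` and `index_of_max` are hypothetical helper names rather than anything from `torch` or `transformers`.

```python
import math

def toy_softmax(logits):
    """Numerically stable softmax over a plain list of floats."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def index_of_max(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical logits for a toy 5-token vocabulary (not model output)
toy_logits = [-7.6, -12.5, 3.1, -2.2, 0.4]
toy_probs = toy_softmax(toy_logits)

# The winning index is the same before and after softmax
same_index = index_of_max(toy_logits) == index_of_max(toy_probs)
print(index_of_max(toy_logits), same_index)
```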

Our predicted token is number *8181*, which we can pass to our `idx2token` dictionary to get the actual word from our vocabulary.

In [26]:
idx2token[argmax.item()]

'abraham'

Okay, the input token at index 2 was "Abraham" (index 0 is \[CLS\] and index 1 is "After"), and we can see this reflected in the output prediction. Where we have not masked a word, we would expect the predicted output token to match (or closely match) the input. Let's try getting the predicted output tokens for every position like so:

In [19]:
softmax = torch.nn.functional.softmax(outputs.prediction_logits[0], dim=1)  # create probability distribution
argmax = torch.argmax(softmax, dim=1)  # get index of the max probability

In [20]:
argmax

tensor([28191,  2348,  8181, 16628,  2180,  3882,  2281,  7313,  4883, 27419,
         2006,  2010,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
         8914,  2163,  4161,  2037,  4336,  2013,  1996,  2406,  2000,  2433,
        28775, 18179, 16363,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
        18232,  2923,  2749,  4548,  3481,  7680,  5017,  2005,  2148,  3792,
        24901,  2074,  2058,  1037,  3204,  2077,  3946,  1005,  1055, 17331,
         1025, 25656])

In [21]:
for i in argmax:
    print(idx2token[i.item()], end=' ')

##ecin although abraham lincolnshire won 1948 november 1860 presidential primaries on his anti - slavery platform , an initial seven tributary states declared their independence from the country to form ##ici confederacy ##yre war broke out in april 1861 when ##oya ##ist forces occupied fort sum ##mer for south carolina ##trip just over a month before grant ' s inauguration ; ##tson 

We can see here that the predicted word for *'election'* is *'primaries'*, which is a reasonably close match, although certainly not correct. For *'attacked'* we see *'occupied'* as the predicted word; again, not correct but pretty close.
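
Rather than reading the masked positions off the printed sequence by eye, they could be located by comparing the input ids against the id of **\[MASK\]**. The sketch below is a rough illustration under two assumptions: that the **\[MASK\]** id is 103 (the value that appears at the masked positions in the `input_ids` printed in the Next Sentence Prediction section), and that the single-sequence input starts with the same ids as that printed pair input. Only the first twelve ids are used, and `MASK_ID`, `input_prefix`, and `predicted_prefix` are illustrative names.

```python
MASK_ID = 103  # assumed id of [MASK], as seen in the printed input_ids

# First 12 input ids (assumed to match the start of the printed pair input)
input_prefix = [101, 2044, 8181, 5367, 2180, 1996,
                2281, 7313, 4883, 103, 2006, 2019]
# First 12 predicted ids, copied from the printed argmax tensor
predicted_prefix = [28191, 2348, 8181, 16628, 2180, 3882,
                    2281, 7313, 4883, 27419, 2006, 2010]

# Positions where the input holds a [MASK] token
mask_positions = [i for i, tok in enumerate(input_prefix) if tok == MASK_ID]

# Predicted ids at exactly those positions
predicted_at_mask = [predicted_prefix[i] for i in mask_positions]
print(mask_positions, predicted_at_mask)
```

Index 9 holds id 27419, which corresponds to the *'primaries'* prediction in the printed output. On the real tensors, the same comparison could be made against `tokenizer.mask_token_id` instead of a hard-coded value.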

## Next Sentence Prediction

Next sentence prediction is slightly different. First, we need to define the two sequences, which we must split using a **\[SEP\]** token and differentiate using the `token_type_ids` tensor.

In [28]:
text = ("After Abraham Lincoln won the November 1860 presidential [MASK] on an "
        "anti-slavery platform, an initial seven slave states declared their "
        "secession from the country to form the Confederacy.")
text2 = ("War broke out in April 1861 when secessionist forces [MASK] Fort "
         "Sumter in South Carolina, just over a month after Lincoln's "
         "inauguration.")

In [29]:
inputs = tokenizer(text, text2, return_tensors="pt")

In [32]:
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,   103,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,   102,  2162,  3631,  2041,  1999,  2258,  6863,
          2043, 22965,  2923,  2749,   103,  3481,  7680,  3334,  1999,  2148,
          3792,  1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055,
         17331,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In our `token_type_ids` tensor here we can see that there are **0** values where we have sentence A, followed by **1** values where we have sentence B. Additionally, in `input_ids`, we have the value **102** (the separator token, **\[SEP\]**) separating the two sequences.

Both of these are done automatically by the tokenizer, and BERT relies on this when we work with multiple sequences, as we do for NSP.
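
To make the segment rule concrete, here is a small sketch that rebuilds `token_type_ids` from a list of input ids: everything up to and including the first **\[SEP\]** gets 0, and everything after it gets 1. This illustrates the rule only and is not the tokenizer's actual implementation; `SEP_ID`, `segment_ids`, and the shortened `toy_ids` list are hypothetical names and values chosen for the example.

```python
SEP_ID = 102  # id of [SEP] in the input_ids shown above

def segment_ids(input_ids, sep_id=SEP_ID):
    """Assign 0 up to and including the first [SEP], then 1 for the rest."""
    first_sep = input_ids.index(sep_id)
    return [0 if i <= first_sep else 1 for i in range(len(input_ids))]

# Shortened, hypothetical pair laid out as: [CLS] a b [SEP] c d e [SEP]
toy_ids = [101, 2044, 8181, 102, 2162, 3631, 2041, 102]
print(segment_ids(toy_ids))
```

Applied to the 63 ids printed above, where the first 102 sits at index 33, this rule yields 34 zeros followed by 29 ones.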

Now we're ready to process the inputs through our model. Originally, `outputs.seq_relationship_logits` was:

```
tensor([[ 2.8256, -1.6897]], grad_fn=<AddmmBackward>)
```

However, we hadn't set up our inputs as a sentence pair at that point, so we should now see a change in these logits.

In [47]:
outputs = model(**inputs)

In [48]:
outputs.seq_relationship_logits

tensor([[ 6.0915, -5.6939]], grad_fn=<AddmmBackward>)

Great, now we process them through an argmax function to get 0 or 1, indicating whether sentence B (marked by `1` in `token_type_ids`) follows sentence A (marked by `0`).

In [71]:
argmax = torch.argmax(outputs.seq_relationship_logits)  # get index of the max activation

In [74]:
argmax

tensor(0)

Index **0** represents BERT's *IsNext* class, meaning that sentence B *is the next* sentence after A. Index **1** represents the *NotNext* class, meaning sentence B is *not* the next sentence after A. We can write this as:

In [75]:
'NotNext' if argmax.item() else 'IsNext' 

'IsNext'

From this, we can see that our model is correctly identifying the two sentences as a pair.
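
The argmax gives the class but not how confident the model is. As a rough sketch, the two NSP logits printed above can be passed through a softmax to get class probabilities. The example below copies those logits by hand and uses only the standard library; `nsp_logits`, `nsp_probs`, and `nsp_labels` are illustrative names, and the label order follows the index convention described above.

```python
import math

# NSP logits copied from the printed output: index 0 = IsNext, index 1 = NotNext
nsp_logits = [6.0915, -5.6939]
nsp_labels = ['IsNext', 'NotNext']

# Softmax over the two classes, shifted by the max for numerical stability
shift = max(nsp_logits)
exps = [math.exp(x - shift) for x in nsp_logits]
nsp_probs = [e / sum(exps) for e in exps]

# Pick the most probable class and report its probability
best = max(range(len(nsp_probs)), key=lambda i: nsp_probs[i])
print(nsp_labels[best], nsp_probs[best])
```

With a gap of almost 12 between the two logits, nearly all of the probability mass lands on *IsNext*.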