# Training With MLM and NSP

We can now fine-tune our model with both MLM and NSP, combining everything we've covered in this section so far.

Again, we'll be using *Meditations* and the `BertForPreTraining` class, so we start by initializing the tokenizer and model and loading the text.

In [1]:
from transformers import BertTokenizer, BertForPreTraining
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

with open('../../data/text/meditations/clean.txt', 'r') as fp:
    text = fp.read().split('\n')

Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
text[:3]

['From my grandfather Verus I learned good morals and the government of my temper.',
 'From the reputation and remembrance of my father, modesty and a manly character.',
 'From my mother, piety and beneficence, and abstinence, not only from evil deeds, but even from evil thoughts; and further, simplicity in my way of living, far removed from the habits of the rich.']

We'll need to modify our data into a mix of non-random sentences, and random sentences for NSP. We do this using the same logic that we used when training just NSP, creating a *'bag'* of random sentences that we can pull from when selecting a random sentence B.

In [3]:
bag = [item for sentence in text for item in sentence.split('.') if item != '']
bag_size = len(bag)
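To see what the bag looks like, here is a minimal sketch on two made-up paragraphs (`toy_text` is hypothetical and stands in for the lines of `clean.txt`). It uses the same comprehension as above.

```python
# Toy paragraphs standing in for the lines of clean.txt (hypothetical data).
toy_text = [
    'First idea. Second idea.',
    'Third idea.',
]

# Same comprehension as above: split each paragraph on '.' and drop empty pieces.
toy_bag = [item for sentence in toy_text for item in sentence.split('.') if item != '']
print(toy_bag)  # ['First idea', ' Second idea', 'Third idea']

# Optional variation (not used in this notebook): strip the leading spaces.
stripped_bag = [item.strip() for item in toy_bag]
```

Note that splitting on `'.'` leaves a leading space on every sentence after the first. The tokenizer ignores that whitespace, so the notebook does not strip it.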

And now we create our 50/50 NSP training data.

In [4]:
import random

sentence_a = []
sentence_b = []
label = []

for paragraph in text:
    sentences = [
        sentence for sentence in paragraph.split('.') if sentence != ''
    ]
    num_sentences = len(sentences)
    if num_sentences > 1:
        start = random.randint(0, num_sentences-2)
        # 50/50 chance of IsNextSentence or NotNextSentence
        if random.random() >= 0.5:
            # this is IsNextSentence
            sentence_a.append(sentences[start])
            sentence_b.append(sentences[start+1])
            label.append(0)
        else:
            index = random.randint(0, bag_size-1)
            # this is NotNextSentence
            sentence_a.append(sentences[start])
            sentence_b.append(bag[index])
            label.append(1)
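The same pairing logic can be wrapped in a function and tried on a few toy paragraphs. This is only a sketch: `make_nsp_pairs`, its `seed` argument, and the toy data are hypothetical names introduced for illustration, and a seeded `random.Random` is used so the run is repeatable.

```python
import random

def make_nsp_pairs(paragraphs, bag, seed=0):
    """Hypothetical helper: build sentence_a, sentence_b and label lists for NSP."""
    rng = random.Random(seed)  # seeded so the toy run is repeatable
    sentence_a, sentence_b, label = [], [], []
    for paragraph in paragraphs:
        sentences = [s for s in paragraph.split('.') if s != '']
        if len(sentences) < 2:
            continue  # a paragraph needs at least two sentences to form a pair
        start = rng.randint(0, len(sentences) - 2)
        sentence_a.append(sentences[start])
        if rng.random() >= 0.5:
            # IsNextSentence: B really follows A, so the label is 0
            sentence_b.append(sentences[start + 1])
            label.append(0)
        else:
            # NotNextSentence: B is drawn at random from the bag, so the label is 1
            sentence_b.append(bag[rng.randint(0, len(bag) - 1)])
            label.append(1)
    return sentence_a, sentence_b, label

toy_paragraphs = ['A one. A two. A three.', 'B one. B two.', 'Single sentence.']
toy_bag = [s for p in toy_paragraphs for s in p.split('.') if s != '']
a, b, y = make_nsp_pairs(toy_paragraphs, toy_bag)
```

The single-sentence paragraph is skipped, so three paragraphs yield two pairs. A random draw from the bag can occasionally pick the true next sentence; with a large bag this is rare, and the logic above does not guard against it.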

We can now tokenize our data, using a truncation/padding `max_length` of *512*.

In [5]:
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt',
                   max_length=512, truncation=True, padding='max_length')

In [6]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

We can now create our *labels* tensor for MLM, and the *next_sentence_label* for NSP.

The *next_sentence_label* tensor is simply a `LongTensor` built from our `label` variable.

In [7]:
inputs['next_sentence_label'] = torch.LongTensor([label]).T
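The transpose is what gives us one label per row. A small sketch with toy labels shows the resulting shape, along with an equivalent construction that is offered only as an alternative.

```python
import torch

label = [0, 1, 1, 0]  # toy NSP labels

# Wrapping the list gives shape (1, 4); transposing gives one label per row.
nsp = torch.LongTensor([label]).T
print(nsp.shape)  # torch.Size([4, 1])

# An equivalent construction (an alternative, not the form used above).
alt = torch.tensor(label, dtype=torch.long).unsqueeze(1)
```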

And the *labels* tensor is simply a clone of the *input_ids* tensor before masking.

In [8]:
inputs['labels'] = inputs.input_ids.detach().clone()

Now we mask tokens in the *input_ids* tensor using the 15% probability for MLM - ensuring we don't mask *CLS*, *SEP*, or *PAD* tokens.

In [9]:
# create random array of floats with equal dimensions to input_ids tensor
rand = torch.rand(inputs.input_ids.shape)
# create mask array
mask_arr = (rand < 0.15) * (inputs.input_ids != 101) * \
           (inputs.input_ids != 102) * (inputs.input_ids != 0)
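As a quick check of this condition, the sketch below builds the mask for a toy batch and confirms that no special token is ever selected. The token ids are made up apart from 101, 102, and 0, and the boolean-operator form is just another way of writing the multiplication used above.

```python
import torch

torch.manual_seed(0)

# Toy batch: 101 = [CLS], 102 = [SEP], 0 = [PAD] in bert-base-uncased.
input_ids = torch.tensor([
    [101, 2023, 2003, 102, 1037, 3231, 102, 0, 0, 0],
    [101, 2178, 2742, 102, 2062, 2616, 2182, 102, 0, 0],
])

rand = torch.rand(input_ids.shape)
special = (input_ids == 101) | (input_ids == 102) | (input_ids == 0)

# Same condition as above, written with boolean operators.
mask_arr = (rand < 0.15) & ~special

# The multiplication form from the notebook, for comparison.
mask_mul = (rand < 0.15) * (input_ids != 101) * \
           (input_ids != 102) * (input_ids != 0)
```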

And now take the indices of each `True` value within each vector.

In [10]:
selection = []

for i in range(inputs.input_ids.shape[0]):
    selection.append(
        torch.flatten(mask_arr[i].nonzero()).tolist()
    )



Then apply these indices to each row in *input_ids*, assigning each value at these indices a value of *103*.

In [11]:
for i in range(inputs.input_ids.shape[0]):
    inputs.input_ids[i, selection[i]] = 103
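The row-by-row loop can also be replaced by a single boolean-indexing assignment. The sketch below uses a hand-picked toy mask and checks that both approaches give the same result. It is an alternative, not the form used in this notebook.

```python
import torch

input_ids = torch.tensor([[101, 2023, 2003, 102, 0],
                          [101, 2178, 2742, 102, 0]])
# Hand-picked mask for illustration; in the notebook this comes from mask_arr.
mask_arr = torch.tensor([[False, True, False, False, False],
                         [False, False, True, False, False]])

# Loop version, as in the cells above.
looped = input_ids.clone()
for i in range(looped.shape[0]):
    selection = torch.flatten(mask_arr[i].nonzero()).tolist()
    looped[i, selection] = 103

# Vectorized version: boolean indexing masks every row at once.
vectorized = input_ids.clone()
vectorized[mask_arr] = 103
```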

In [12]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'next_sentence_label', 'labels'])
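Because *labels* is a full copy of *input_ids*, the MLM loss in this walkthrough covers every position, including unmasked and padding tokens. One common variation, which this notebook does not use, is to set the label to -100 everywhere except the masked positions: PyTorch's cross-entropy ignores targets equal to its default `ignore_index` of -100, and the Hugging Face docs describe the same convention for `labels`. A minimal sketch with toy tensors and random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

labels = torch.tensor([[101, 2023, 2003, 102, 0]])           # toy copy of input_ids
mask_arr = torch.tensor([[False, True, False, False, False]])  # toy mask

# Keep the true token id only where a token was masked; ignore everything else.
mlm_labels = labels.clone()
mlm_labels[~mask_arr] = -100

# Random logits stand in for model output; vocab size matches bert-base-uncased.
vocab_size = 30522
logits = torch.randn(1, 5, vocab_size)

# cross_entropy skips targets equal to ignore_index (default -100).
loss = F.cross_entropy(logits.view(-1, vocab_size), mlm_labels.view(-1))
```

With this change the loss depends only on the masked position, here the second token.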

We now have all of our inputs and labels ready, so we can begin setting `inputs` up to be fed into our model during training. We create a PyTorch dataset from our data.

In [13]:
class MeditationsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        # values are already tensors, so copy them rather than re-wrapping with torch.tensor()
        return {key: val[idx].detach().clone() for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)

Initialize our data using the `MeditationsDataset` class.

In [14]:
dataset = MeditationsDataset(inputs)

And initialize the dataloader, which we'll be using to load our data into the model during training.

In [15]:
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
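To check what the dataloader yields, the sketch below uses a hypothetical `ToyDataset` over a plain dict of zero tensors in place of the tokenizer output. With 40 samples and a batch size of 16, we expect two full batches and a final batch of 8.

```python
import torch

class ToyDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for MeditationsDataset using a plain dict of tensors."""
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings['input_ids'])

# 40 fake samples of sequence length 8.
encodings = {
    'input_ids': torch.zeros(40, 8, dtype=torch.long),
    'next_sentence_label': torch.zeros(40, 1, dtype=torch.long),
}
toy_loader = torch.utils.data.DataLoader(ToyDataset(encodings), batch_size=16, shuffle=True)

# Each batch is a dict of stacked tensors, one entry per key.
batch_sizes = [batch['input_ids'].shape[0] for batch in toy_loader]
first_batch = next(iter(toy_loader))
```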

Now we can move on to setting up the training loop. First we set up GPU/CPU usage.

In [16]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# and move our model over to the selected device
model.to(device)

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

Activate the training mode of our model, and initialize our optimizer (Adam with weight decay, which helps reduce the chance of overfitting).

In [17]:
# transformers' own AdamW is deprecated in favor of the PyTorch implementation
from torch.optim import AdamW

# activate training mode
model.train()
# initialize optimizer
optim = AdamW(model.parameters(), lr=5e-5)

Now we can move on to the training loop. We'll train for two epochs (change `epochs` to modify this).

In [18]:
from tqdm.notebook import tqdm  # for our progress bar

epochs = 2

for epoch in range(epochs):
    # setup loop with TQDM and dataloader
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # reset gradients accumulated in the previous step
        optim.zero_grad()
        # pull all tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        next_sentence_label = batch['next_sentence_label'].to(device)
        labels = batch['labels'].to(device)
        # process
        outputs = model(input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        next_sentence_label=next_sentence_label,
                        labels=labels)
        # extract loss
        loss = outputs.loss
        # backpropagate: calculate gradients for every parameter that requires them
        loss.backward()
        # update parameters
        optim.step()
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
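When both `labels` and `next_sentence_label` are passed, the single `outputs.loss` value is, to the best of our understanding of the Hugging Face implementation, the sum of a cross-entropy loss over the MLM logits and a cross-entropy loss over the NSP logits. The sketch below reproduces that combination with random tensors standing in for `outputs.prediction_logits` and `outputs.seq_relationship_logits`. It illustrates the idea and is not the library's own code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab_size = 2, 6, 30522

# Random stand-ins for the two heads of the model output.
prediction_logits = torch.randn(batch, seq_len, vocab_size)  # MLM head
seq_relationship_logits = torch.randn(batch, 2)              # NSP head

# Toy targets with the same shapes as the notebook's label tensors.
labels = torch.randint(0, vocab_size, (batch, seq_len))
next_sentence_label = torch.tensor([[0], [1]])

# Token-level loss for MLM, sequence-level loss for NSP, then the sum.
mlm_loss = F.cross_entropy(prediction_logits.view(-1, vocab_size), labels.view(-1))
nsp_loss = F.cross_entropy(seq_relationship_logits.view(-1, 2), next_sentence_label.view(-1))
total_loss = mlm_loss + nsp_loss
```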

That's it: we've now fine-tuned BERT using both NSP and MLM.