# NSP training step

In [8]:
from transformers import BertTokenizer, BertForNextSentencePrediction
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
with open("../data/text/meditations/clean.txt", 'r') as f:
    text = f.read().split('\n')
text[:3]

['From my grandfather Verus I learned good morals and the government of my temper.',
 'From the reputation and remembrance of my father, modesty and a manly character.',
 'From my mother, piety and beneficence, and abstinence, not only from evil deeds, but even from evil thoughts; and further, simplicity in my way of living, far removed from the habits of the rich.']

In [5]:
# split text by "." 
bag = [sentence for para in text for sentence in para.split('.') if sentence != '']
bag_len = len(bag)
print(bag_len)

1372


In [6]:
import random
sentence_a = []
sentence_b = []
label = []

for paragraph in text:
    sentences = [sentence for sentence in paragraph.split('.') if sentence!='']
    num_sentences = len(sentences)
    if num_sentences>1:
        start = random.randint(0, num_sentences-2)
        if random.random()>=0.5:
            sentence_a.append(sentences[start])
            sentence_b.append(sentences[start+1])
            label.append(0)
        else:
            index = random.randint(0, bag_len-1)
            sentence_a.append(sentences[start])
            sentence_b.append(bag[index])
            label.append(1)

In [7]:
for i in range(3):
    print(label[i])
    print(sentence_a[i] + '\n---')
    print(sentence_b[i] + '\n')

1
From Maximus I learned self-government, and not to be led aside by anything; and cheerfulness in all circumstances, as well as in illness; and a just admixture in the moral character of sweetness and dignity, and to do what was set before me without complaining
---
Whatever of the things which are not within thy power thou shalt suppose to be good for thee or evil, it must of necessity be that, if such a bad thing befall thee or the loss of such a good thing, thou wilt blame the gods, and hate men too, those who are the cause of the misfortune or the loss, or those who are suspected of being likely to be the cause; and indeed we do much injustice, because we make a difference between these things

0
 His secrets were not but very few and very rare, and these only about public matters; and he showed prudence and economy in the exhibition of the public spectacles and the construction of public buildings, his donations to the people, and in such things, for he was a man who looked to wh

In [9]:
inputs = tokenizer(
            sentence_a, sentence_b, 
            return_tensors='pt',
            max_length=512, 
            truncation=True,
            padding='max_length')

inputs['labels'] = torch.LongTensor([label]).T

In [11]:
class NSPDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __len__(self):
        return len(self.encodings.input_ids)
    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
dataset = NSPDataset(inputs)

In [12]:
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

In [13]:
device = torch.device('mps')
model.to(device)
model.train()

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [15]:
from tqdm import tqdm
epochs=2
optim = torch.optim.AdamW(model.parameters() ,lr=5e-5)

for epoch in range(epochs):
    loop = tqdm(loader, leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
                    input_ids, 
                    attention_mask=attention_mask, 
                    token_type_ids=token_type_ids, 
                    labels=labels)
        
        loss = outputs.loss
        loss.backward()
        optim.step()

        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss = loss.item())

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Epoch 0:  10%|â–ˆ         | 2/20 [01:15<11:23, 38.00s/it, loss=2.06]


KeyboardInterrupt: 