# Building BERT with Pytorch from scratch

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
device = 'cuda'

### Prepare the dataset
######To prepare dataset, we do next:
###### - Split dataset on sentences
###### - Create vocabulary for word - token pair, for example {'go': 45}
###### - Create training dataset
###### - Add special tokens to the sentence
###### - Mask 15% of words in the sentence
###### - Pad sentence to predefined length
###### - Create NSP item from two sentences

In [4]:
from torch.utils.data import Dataset
import pandas as pd
from torchtext.data import get_tokenizer
from torchtext.vocab import vocab
from collections import Counter
import numpy as np
import typing
from tqdm import tqdm
import random
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as f

In [8]:
class IMDBBertDataset(Dataset):
    # Define Special tokens as attributes of class
    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'
    MASK = '[MASK]'
    UNK = '[UNK]'

    MASK_PERCENTAGE = 0.15  # How much words to mask

    MASKED_INDICES_COLUMN = 'masked_indices'
    TARGET_COLUMN = 'indices'
    NSP_TARGET_COLUMN = 'is_next'
    TOKEN_MASK_COLUMN = 'token_mask'

    OPTIMAL_LENGTH_PERCENTILE = 70

    def __init__(self, path, ds_from=None, ds_to=None, should_include_text=False):
        self.ds: pd.Series = pd.read_csv(path, engine='python')['review']

        if ds_from is not None or ds_to is not None:
            self.ds = self.ds[ds_from:ds_to]

        self.tokenizer = get_tokenizer('basic_english')
        self.counter = Counter()
        self.vocab = None

        self.optimal_sentence_length = None
        self.should_include_text = should_include_text

        if should_include_text:
            self.columns = ['masked_sentence', self.MASKED_INDICES_COLUMN, 'sentence', self.TARGET_COLUMN,
                            self.TOKEN_MASK_COLUMN,
                            self.NSP_TARGET_COLUMN]
        else:
            self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.TOKEN_MASK_COLUMN,
                            self.NSP_TARGET_COLUMN]
        self.df = self.prepare_dataset()

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]

        inp = torch.Tensor(item[self.MASKED_INDICES_COLUMN]).long()
        token_mask = torch.Tensor(item[self.TOKEN_MASK_COLUMN]).bool()

        attention_mask = (inp == self.vocab[self.PAD]).unsqueeze(0)

        # NSP target
        if item[self.NSP_TARGET_COLUMN] == 0:
            t = [1, 0]
        else:
            t = [0, 1]
        nsp_target = torch.Tensor(t)

        # MLM target
        mask_target = torch.Tensor(item[self.TARGET_COLUMN]).long()  
        mask_target = mask_target.masked_fill_(token_mask, 0)
        return inp, attention_mask, token_mask, mask_target, nsp_target

    def prepare_dataset(self) -> pd.DataFrame:
        sentences = []
        nsp = []
        sentence_lens = []

        # Split sentences from dataset: 
        
        for review in self.ds:
            review_sentences = review.split(".")
            sentences += review_sentences
            self._update_length(review_sentences, sentence_lens)
        self.optimal_sentence_length = self._find_optimal_sentence_length(sentence_lens)
        
        # Create vocabulary:

        print("Create vocabulary")  
        for sentence in tqdm(sentences):  
            s = self.tokenizer(sentence)  
            self.counter.update(s)  
        self._fill_vocab()

        # Create training dataset:

        print("Preprocessing dataset")
        for review in tqdm(self.ds):
            review_sentences = review.split('.')
            for i in range(len(review_sentences)-1):
                # True NSP item
                first, second = self.tokenizer(review_sentences[i]), self.tokenizer(review_sentences[i+1])
                nsp.append(self._create_item(first, second, 1))

                # False NSP item
                first, second = self._select_false_nsp_sentences(sentences)
                first, second = self.tokenizer(first), self.tokenizer(second)
                nsp.append(self._create_item(first, second, 0))
        df = pd.DataFrame(nsp, columns=self.columns)
        return df

    def _update_length(self, review_sentences, sentence_lens):
        for word in review_sentences:
            sentence_lens.append(len(word)) 
    
    def _find_optimal_sentence_length(self, lengths: typing.List[int]):  
        arr = np.array(lengths)  
        return int(np.percentile(arr, self.OPTIMAL_LENGTH_PERCENTILE))
    
    def _fill_vocab(self):  
        # specials= argument is only in 0.12.0 version  
        # specials=[self.CLS, self.PAD, self.MASK, self.SEP, self.UNK]
        self.vocab = vocab(self.counter, min_freq=2)  
        # 0.11.0 uses this approach to insert specials  
        self.vocab.insert_token(self.CLS, 0)  
        self.vocab.insert_token(self.PAD, 1)  
        self.vocab.insert_token(self.MASK, 2)  
        self.vocab.insert_token(self.SEP, 3)  
        self.vocab.insert_token(self.UNK, 4)  
        self.vocab.set_default_index(4)

    def _create_item(self, first: typing.List[str], second: typing.List[str], target: int = 1):  
        # Create masked sentence item  
        updated_first, first_mask = self._preprocess_sentence(first.copy())  
        updated_second, second_mask = self._preprocess_sentence(second.copy())
        nsp_sentence = updated_first + [self.SEP] + updated_second  
        nsp_indices = self.vocab.lookup_indices(nsp_sentence)  
        inverse_token_mask = first_mask + [True] + second_mask

        # Create sentence item without masking random words  
        first, _ = self._preprocess_sentence(first.copy(), should_mask=False)  
        second, _ = self._preprocess_sentence(second.copy(), should_mask=False)  
        original_nsp_sentence = first + [self.SEP] + second  
        original_nsp_indices = self.vocab.lookup_indices(original_nsp_sentence)

        if self.should_include_text:
            return [nsp_sentence, nsp_indices, original_nsp_sentence, original_nsp_indices, inverse_token_mask, target]
        else:
            return [nsp_indices, original_nsp_indices, inverse_token_mask, target]

    def _select_false_nsp_sentences(self, sentences):
        return random.choice(sentences), random.choice(sentences)

    def _preprocess_sentence(self, sentence, should_mask = True):
        inverse_token_mask = [True for _ in range(max(len(sentence), self.optimal_sentence_length))]
        if should_mask:
            sentence, inverse_token_mask = self._mask_sentence(sentence)
        return self._pad_sentence(sentence, inverse_token_mask)

    # Step 1: Mask sentence

    def _mask_sentence(self, sentence: typing.List[str]):  
        len_s = len(sentence)  
        inverse_token_mask = [True for _ in range(max(len_s, self.optimal_sentence_length))]  
    
        mask_amount = round(len_s * self.MASK_PERCENTAGE)  
        for _ in range(mask_amount):  
            i = random.randint(0, len_s - 1) 
            j = random.randint(5, len(self.vocab)-1) 
    
            if random.random() < 0.8:  
                sentence[i] = self.MASK  
            else:
                sentence[i] = self.vocab.lookup_token(j)  
            inverse_token_mask[i] = False  
        
        return sentence, inverse_token_mask
    
    # Step 2:Preprocessing: [CLS] and [PAD] sentence

    def _pad_sentence(self, sentence: typing.List[str], inverse_token_mask: typing.List[bool] = None):  
        len_s = len(sentence)  
    
        if len_s >= self.optimal_sentence_length:  
            s = sentence[:self.optimal_sentence_length]  
        else:  
            s = sentence + [self.PAD] * (self.optimal_sentence_length - len_s)  
    
        # inverse token mask should be padded as well  
        if inverse_token_mask:  
            len_m = len(inverse_token_mask)  
            if len_m >= self.optimal_sentence_length:  
                inverse_token_mask = inverse_token_mask[:self.optimal_sentence_length]  
            else:  
                inverse_token_mask = inverse_token_mask + [True] * (self.optimal_sentence_length - len_m)  
        return s, inverse_token_mask

In [9]:
%%time
mydataset = IMDBBertDataset('/content/drive/MyDrive/BERT with Pytorch from Scratch/data/IMDB Dataset.csv', ds_from=0, ds_to=500)

Create vocabulary


100%|██████████| 7123/7123 [00:00<00:00, 61517.36it/s]


Preprocessing dataset


100%|██████████| 500/500 [00:01<00:00, 294.65it/s]

CPU times: user 2.92 s, sys: 160 ms, total: 3.08 s
Wall time: 3.11 s





In [10]:
mydataset.df

Unnamed: 0,masked_indices,indices,token_mask,is_next
0,"[5, 6, 7, 8, 9, 2, 11, 2, 13, 14, 15, 16, 17, ...","[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17...","[True, True, True, True, True, False, True, Fa...",1
1,"[2, 20, 211, 2, 48, 7, 320, 194, 56, 589, 27, ...","[292, 20, 211, 1949, 48, 7, 320, 194, 56, 589,...","[False, True, True, False, True, True, True, T...",0
2,"[24, 25, 26, 27, 2, 29, 30, 1978, 32, 33, 34, ...","[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 3...","[True, True, True, True, False, True, True, Fa...",1
3,"[1306, 2186, 199, 2, 54, 1002, 2236, 142, 2, 4...","[1306, 2186, 199, 402, 54, 1002, 2236, 142, 46...","[True, True, True, False, True, True, True, Tr...",0
4,"[7, 36, 37, 12, 2, 35, 39, 2, 40, 41, 4, 2, 43...","[7, 36, 37, 12, 38, 35, 39, 17, 40, 41, 4, 42,...","[True, True, True, True, False, True, True, Fa...",1
...,...,...,...,...
13241,"[29, 320, 465, 7, 4, 3914, 2642, 6, 7, 3898, 1...","[29, 320, 465, 7, 4, 3914, 2642, 6, 7, 3898, 1...","[True, True, True, True, True, True, True, Tru...",0
13242,"[32, 101, 2, 310, 63, 109, 39, 4031, 3090, 20,...","[32, 101, 30, 310, 63, 109, 39, 54, 3090, 20, ...","[True, True, False, True, True, True, True, Fa...",1
13243,"[692, 2, 339, 20, 328, 2369, 24, 20, 2, 86, 22...","[692, 170, 339, 20, 328, 2369, 24, 20, 21, 86,...","[True, False, True, True, True, True, True, Tr...",0
13244,"[527, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[527, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...","[True, True, True, True, True, True, True, Tru...",1


### Create sample input tensor

In [11]:
from torch.utils.data import DataLoader
mydataloader = DataLoader(mydataset, batch_size=5, shuffle=True)

In [12]:
item = next(iter(mydataloader))
print(item)

[tensor([[   2, 5074,   30,  ...,    1,    1,    1],
        [  86,    2,    2,  ...,    1,    1,    1],
        [ 722,   27, 1105,  ...,    1,    1,    1],
        [   2,  240,   27,  ...,    1,    1,    1],
        [  29,  320,   30,  ...,    1,    1,    1]]), tensor([[[False, False, False,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True]],

        [[False, False, False,  ...,  True,  True,  True]]]), tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True, False, False,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]]), tensor([[5346,    0,    0,  ...,    0,    0,    0],
        [   0,   19,  202,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,   

In [13]:
batch_input, batch_attention_mask, inverse_token_mask, token_target, nsp_target = item
batch_input = batch_input.to(device)
batch_attention_mask = batch_attention_mask.to(device)
inverse_token_mask = inverse_token_mask.to(device)
token_target = token_target.to(device)
nsp_target = nsp_target.to(device)

###Build pyTorch model

##### Joint Embedding
###### - Token embedding: is used to encode word tokens. 
###### - Segment embedding: encodes belonging to the first or to the second sentence. We preprocess input sequence the next way: if the token belongs to the first sentence, set 0, otherwise set 1.
###### - Position embedding: encodes the position of the word in the sentence (using periodic functions to encode positions)

In [16]:
class JointEmbedding(nn.Module):

    def __init__(self, vocab_size, size):
        super(JointEmbedding, self).__init__()

        self.size = size

        self.token_emb = nn.Embedding(vocab_size, size)
        self.segment_emb = nn.Embedding(2, size)

        self.norm = nn.LayerNorm(size)

    def forward(self, input_tensor):
        sentence_size = input_tensor.size(-1)
        pos_tensor = self.attention_position(self.size, input_tensor)

        segment_tensor = torch.zeros_like(input_tensor).to(input_tensor.device)
        segment_tensor[:, sentence_size // 2 + 1:] = 1

        output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor
        return self.norm(output)

    def attention_position(self, dim, input_tensor):
        batch_size = input_tensor.size(0)
        sentence_size = input_tensor.size(-1)

        pos = torch.arange(sentence_size, dtype=torch.long).to(input_tensor.device)
        d = torch.arange(dim, dtype=torch.long).to(input_tensor.device)
        d = (2 * d / dim)

        pos = pos.unsqueeze(1)
        pos = pos / (1e4 ** d)

        pos[:, ::2] = torch.sin(pos[:, ::2])
        pos[:, 1::2] = torch.cos(pos[:, 1::2])

        return pos.expand(batch_size, *pos.size())

    def numeric_position(self, dim, input_tensor):
        pos_tensor = torch.arange(dim, dtype=torch.long).to(input_tensor.device)
        return pos_tensor.expand_as(input_tensor)

In [17]:
jointemb = JointEmbedding(len(mydataset.vocab), 128)
jointemb.to(device)

JointEmbedding(
  (token_emb): Embedding(6314, 128)
  (segment_emb): Embedding(2, 128)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)

In [18]:
input_tensor = batch_input

In [19]:
x = jointemb(input_tensor)

In [20]:
x.shape

torch.Size([5, 241, 128])

###Attention Head
######(the heart of Transformer)

In [21]:
class AttentionHead(nn.Module):  
  
    def __init__(self, dim_inp, dim_out):  
        super(AttentionHead, self).__init__()  
  
        self.dim_inp = dim_inp  
  
        self.q = nn.Linear(dim_inp, dim_out)  
        self.k = nn.Linear(dim_inp, dim_out)  
        self.v = nn.Linear(dim_inp, dim_out)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor = None):  
        query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor)  
  
        scale = query.size(1) ** 0.5  
        scores = torch.bmm(query, key.transpose(1, 2)) / scale  
  
        scores = scores.masked_fill_(attention_mask, -1e9)  
        attn = f.softmax(scores, dim=-1)  
        context = torch.bmm(attn, value)  
  
        return context

In [22]:
myAttentionHead = AttentionHead(128,512)
myAttentionHead.to(device)

AttentionHead(
  (q): Linear(in_features=128, out_features=512, bias=True)
  (k): Linear(in_features=128, out_features=512, bias=True)
  (v): Linear(in_features=128, out_features=512, bias=True)
)

In [23]:
query, key, value = myAttentionHead.q(x), myAttentionHead.k(x), myAttentionHead.v(x)  

scale = query.size(1) ** 0.5  
scores = torch.bmm(query, key.transpose(1, 2)) / scale  

scores = scores.masked_fill_(batch_attention_mask, -1e9)  
attn = f.softmax(scores, dim=-1)  
context = torch.bmm(attn, value) 
 

In [24]:
scores = torch.bmm(query, key.transpose(1, 2)) / scale
scores.shape

torch.Size([5, 241, 241])

In [25]:
scores = scores.masked_fill_(batch_attention_mask, -1e9)
scores[0]

tensor([[ 1.4754e-01, -1.5687e-01,  2.4132e-01,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 1.9747e-01, -4.8152e-02, -5.7293e-02,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 2.6474e-01,  3.3935e-01,  4.9777e-01,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        ...,
        [ 7.2873e-01, -9.4207e-02,  7.5281e-01,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 6.6566e-01, -9.5213e-02,  7.1552e-01,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 6.1017e-01, -1.3832e-01,  6.8698e-01,  ..., -1.0000e+09,
         -1.0000e+09, -1.0000e+09]], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [26]:
attn = f.softmax(scores, dim=-1) 
attn.shape

torch.Size([5, 241, 241])

###MultiHead Attention

######Single attention layer (head) is restricted to learn only the information from one particular subspace. Multi-head attention is the set of parallel attention heads that learns to retrieve the information from different representations. You may look on them as on filters in Convolutional Neural Networks.

In [27]:
class MultiHeadAttention(nn.Module):  
  
    def __init__(self, num_heads, dim_inp, dim_out):  
        super(MultiHeadAttention, self).__init__()  
  
        self.heads = nn.ModuleList([  
            AttentionHead(dim_inp, dim_out) for _ in range(num_heads)  
        ])  
        self.linear = nn.Linear(dim_out * num_heads, dim_inp)  
        self.norm = nn.LayerNorm(dim_inp)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        s = [head(input_tensor, attention_mask) for head in self.heads]  
        scores = torch.cat(s, dim=-1)  
        scores = self.linear(scores)  
        return self.norm(scores)

In [28]:
myMutiHeadAttention = MultiHeadAttention(2, 128, 512)
myMutiHeadAttention.to(device)

MultiHeadAttention(
  (heads): ModuleList(
    (0): AttentionHead(
      (q): Linear(in_features=128, out_features=512, bias=True)
      (k): Linear(in_features=128, out_features=512, bias=True)
      (v): Linear(in_features=128, out_features=512, bias=True)
    )
    (1): AttentionHead(
      (q): Linear(in_features=128, out_features=512, bias=True)
      (k): Linear(in_features=128, out_features=512, bias=True)
      (v): Linear(in_features=128, out_features=512, bias=True)
    )
  )
  (linear): Linear(in_features=1024, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)

In [29]:
s = [head(x, batch_attention_mask) for head in myMutiHeadAttention.heads]
scores = torch.cat(s, dim=-1)
scores.shape 

torch.Size([5, 241, 1024])

###Encoder
###### (For simplicity, we use only one layer)


In [30]:
class Encoder(nn.Module):  
  
    def __init__(self, dim_inp, dim_out, attention_heads=2, dropout=0.1):  
        super(Encoder, self).__init__()  
  
        self.attention = MultiHeadAttention(attention_heads, dim_inp, dim_out) 
        self.feed_forward = nn.Sequential(  
            nn.Linear(dim_inp, dim_out),  
            nn.Dropout(dropout),  
            nn.GELU(),  
            nn.Linear(dim_out, dim_inp),  
            nn.Dropout(dropout)  
        )
        self.norm = nn.LayerNorm(dim_inp)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        context = self.attention(input_tensor, attention_mask)  
        res = self.feed_forward(context)  
        return self.norm(res)

In [31]:
myEncoder = Encoder(128, 512)
myEncoder.to(device)

Encoder(
  (attention): MultiHeadAttention(
    (heads): ModuleList(
      (0): AttentionHead(
        (q): Linear(in_features=128, out_features=512, bias=True)
        (k): Linear(in_features=128, out_features=512, bias=True)
        (v): Linear(in_features=128, out_features=512, bias=True)
      )
      (1): AttentionHead(
        (q): Linear(in_features=128, out_features=512, bias=True)
        (k): Linear(in_features=128, out_features=512, bias=True)
        (v): Linear(in_features=128, out_features=512, bias=True)
      )
    )
    (linear): Linear(in_features=1024, out_features=128, bias=True)
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (feed_forward): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): Dropout(p=0.1, inplace=False)
    (2): GELU(approximate='none')
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): Dropout(p=0.1, inplace=False)
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_af

In [32]:
myEncoder(x, batch_attention_mask).shape

torch.Size([5, 241, 128])

###BERT
######BERT module is a container that combines all the modules together and returns the output.

In [33]:
class BERT(nn.Module):  
  
    def __init__(self, vocab_size, dim_inp, dim_out, attention_heads):  
        super(BERT, self).__init__()  
  
        self.embedding = JointEmbedding(vocab_size, dim_inp)  
        self.encoder = Encoder(dim_inp, dim_out, attention_heads)  
  
        self.token_prediction_layer = nn.Linear(dim_inp, vocab_size)  
        self.softmax = nn.LogSoftmax(dim=-1)  
        self.classification_layer = nn.Linear(dim_inp, 2)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        embedded = self.embedding(input_tensor)  
        encoded = self.encoder(embedded, attention_mask)  
  
        token_predictions = self.token_prediction_layer(encoded)  
  
        first_word = encoded[:, 0, :]  
        return self.softmax(token_predictions), self.classification_layer(first_word)

In [34]:
myBERT = BERT(len(mydataset.vocab), 128, 512, 2)
myBERT.to(device)

BERT(
  (embedding): JointEmbedding(
    (token_emb): Embedding(6314, 128)
    (segment_emb): Embedding(2, 128)
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (encoder): Encoder(
    (attention): MultiHeadAttention(
      (heads): ModuleList(
        (0): AttentionHead(
          (q): Linear(in_features=128, out_features=512, bias=True)
          (k): Linear(in_features=128, out_features=512, bias=True)
          (v): Linear(in_features=128, out_features=512, bias=True)
        )
        (1): AttentionHead(
          (q): Linear(in_features=128, out_features=512, bias=True)
          (k): Linear(in_features=128, out_features=512, bias=True)
          (v): Linear(in_features=128, out_features=512, bias=True)
        )
      )
      (linear): Linear(in_features=1024, out_features=128, bias=True)
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (feed_forward): Sequential(
      (0): Linear(in_features=128, out_features=512, bias=True)


In [35]:
token_predictions, first_word = myBERT(input_tensor, batch_attention_mask)

In [36]:
token_predictions.shape

torch.Size([5, 241, 6314])

In [37]:
first_word.shape

torch.Size([5, 2])

###Train the model


In [39]:
from torch.utils.tensorboard import SummaryWriter
import time
from pathlib import Path
import os

In [40]:
class BertTrainer:

    def __init__(self,
                 model: BERT,
                 dataset: IMDBBertDataset,
                 log_dir: Path,
                 checkpoint_dir: Path = None,
                 print_progress_every: int = 50,
                 batch_size: int = 24,
                 learning_rate: float = 0.005,
                 epochs: int = 5,
                 device: str = 'cpu',
                 ):
        self.model = model
        self.dataset = dataset
        self.device = device

        self.batch_size = batch_size
        self.epochs = epochs
        self.current_epoch = 0

        self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)

        self.writer = SummaryWriter(str(log_dir))
        self.checkpoint_dir = Path(checkpoint_dir)
        self._print_every = print_progress_every

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)
        self.ml_criterion = nn.NLLLoss(ignore_index=0).to(self.device)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)

    def train(self, epoch: int):
        print(f"Begin epoch {epoch}")

        prev = time.time()
        average_nsp_loss = 0
        average_mlm_loss = 0
        for i, value in enumerate(self.loader):
            index = i + 1
            inp, mask, inverse_token_mask, token_target, nsp_target = value
            inp = inp.to(self.device)
            mask = mask.to(self.device)
            inverse_token_mask = inverse_token_mask.to(self.device)
            token_target = token_target.to(self.device)
            nsp_target = nsp_target.to(self.device)

            self.optimizer.zero_grad()

            token, nsp = self.model(inp, mask)

            tm = inverse_token_mask.unsqueeze(-1).expand_as(token)
            token = token.masked_fill(tm, 0)

            loss_token = self.ml_criterion(token.transpose(1, 2), token_target)
            loss_nsp = self.criterion(nsp, nsp_target)

            loss = loss_token + loss_nsp
            average_nsp_loss += loss_nsp
            average_mlm_loss += loss_token

            loss.backward()
            self.optimizer.step()

            if index % self._print_every == 0:
                elapsed = time.gmtime(time.time() - prev)

                log_nsp_loss = average_nsp_loss / self._print_every
                log_mlm_loss = average_mlm_loss / self._print_every
                log_nsp_acc = 100*(nsp.argmax(1) == nsp_target.argmax(1)).sum() / nsp.size(0)
                log_mlm_acc = 100*(token.argmax(-1).masked_select(~inverse_token_mask) == token_target.masked_select(~inverse_token_mask)).sum() / (token.size(0) * token.size(1))

                print(f"{time.strftime('%H:%M:%S', elapsed)} | Epoch {epoch} | Step {index}/{len(self.loader)} | "
                      f"NSP Loss: {log_nsp_loss:.2f} | MLM Loss: {log_mlm_loss:.2f} | NSP Accuracy: {log_nsp_acc:.2f}% | MLM Accuracy: {log_mlm_acc:.2f}%")

                global_step = index + epoch*len(self.loader)
                self.writer.add_scalar("NSP loss", log_nsp_loss, global_step=global_step)
                self.writer.add_scalar("MLM loss", log_mlm_loss, global_step=global_step)
                self.writer.add_scalar("NSP accuracy", log_nsp_acc, global_step=global_step)
                self.writer.add_scalar("Token accuracy", log_mlm_acc, global_step=global_step)

                average_nsp_loss = 0
                average_mlm_loss = 0
        return loss

    def __call__(self):
        if self.checkpoint_dir and os.path.exists(self.checkpoint_dir.joinpath("checkpoint_last.txt")):
            with open(self.checkpoint_dir.joinpath("checkpoint_last.txt")) as f:
                name = f.readline().strip()
            self.load_checkpoint(self.checkpoint_dir.joinpath(name))
            start_epoch = self.current_epoch + 1
        else:
            start_epoch = 0

        for self.current_epoch in range(start_epoch, self.epochs):
            loss = self.train(self.current_epoch)
            self.save_checkpoint(epoch=self.current_epoch, loss=loss)

    def save_checkpoint(self, epoch, loss):
        if not self.checkpoint_dir:
            return

        prev = time.time()
        name = f"checkpoint_epoch{epoch}.pt"
        print(f"Saving model checkpoint epoch {epoch} to {self.checkpoint_dir.joinpath(name)}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': loss,
        }, self.checkpoint_dir.joinpath(name))

        with open(self.checkpoint_dir.joinpath("checkpoint_last.txt"), "w") as f:
            f.write(name)

    def load_checkpoint(self, path: Path):
        print(f"Loading model checkpoint from {path}")
        checkpoint = torch.load(path)
        self.current_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded at epoch {self.current_epoch}.")

In [48]:
mytrainer = BertTrainer(myBERT, mydataset, log_dir='./checkpoints/logs', checkpoint_dir='./checkpoints',
                        epochs=5, device=device)

In [49]:
mytrainer()

Begin epoch 0
00:00:02 | Epoch 0 | Step 50/552 | NSP Loss: 0.70 | MLM Loss: 6.40 | NSP Accuracy: 41.67% | MLM Accuracy: 0.05%
00:00:04 | Epoch 0 | Step 100/552 | NSP Loss: 0.70 | MLM Loss: 6.40 | NSP Accuracy: 41.67% | MLM Accuracy: 0.07%
00:00:06 | Epoch 0 | Step 150/552 | NSP Loss: 0.69 | MLM Loss: 6.39 | NSP Accuracy: 45.83% | MLM Accuracy: 0.12%
00:00:08 | Epoch 0 | Step 200/552 | NSP Loss: 0.70 | MLM Loss: 6.34 | NSP Accuracy: 25.00% | MLM Accuracy: 0.10%
00:00:10 | Epoch 0 | Step 250/552 | NSP Loss: 0.70 | MLM Loss: 6.40 | NSP Accuracy: 37.50% | MLM Accuracy: 0.19%
00:00:12 | Epoch 0 | Step 300/552 | NSP Loss: 0.70 | MLM Loss: 6.36 | NSP Accuracy: 50.00% | MLM Accuracy: 0.10%
00:00:14 | Epoch 0 | Step 350/552 | NSP Loss: 0.69 | MLM Loss: 6.35 | NSP Accuracy: 54.17% | MLM Accuracy: 0.28%
00:00:16 | Epoch 0 | Step 400/552 | NSP Loss: 0.70 | MLM Loss: 6.41 | NSP Accuracy: 54.17% | MLM Accuracy: 0.07%
00:00:18 | Epoch 0 | Step 450/552 | NSP Loss: 0.70 | MLM Loss: 6.38 | NSP Accuracy: