In [1]:
import tqdm
import math
import random
import torch
import numpy as np
import torch.nn as nn

from torch.optim import Adam
import torch.nn.functional as F

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from collections import Counter

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
class BERTDataset(Dataset):
    def __init__(self, corpus_path = './data/eng-fra.txt', vocab = None , seq_len = 20):
        self.vocab = vocab
        self.seq_len = seq_len
        self.corpus_path = corpus_path
        self.lines = []

        # Reopen the file to read the lines
        with open(self.corpus_path , "r", encoding="utf-8") as f:
            for line in tqdm.tqdm(f, desc="Loading Dataset"):
                self.lines.append(line.replace('\n', '').split('\t'))

        self.corpus_lines = len(self.lines)

    def get_corpus_line(self, item):
        return self.lines[item][0], self.lines[item][1]
    
    def get_random_line(self):
        return self.lines[random.randrange(self.corpus_lines)][1]
    
    def random_sent(self, index):
        t1, t2 = self.get_corpus_line(index)

        # output_text, label(isNotNext:0, isNext:1)
        if random.random() > 0.5:
            return t1, t2, 1
        else:
            return t1, self.get_random_line(), 0
        
    def random_word(self, sentence):
        tokens = sentence.split()
        output_label = []

        for i, token in enumerate(tokens):
            prob = random.random()
            if prob < 0.15:
                prob /= 0.15

                # 80% randomly change token to mask token
                if prob < 0.8:
                    tokens[i] = self.vocab.mask_index

                # 10% randomly change token to random token
                elif prob < 0.9:
                    tokens[i] = random.randrange(len(self.vocab))

                # 10% randomly change token to current token
                else:
                    tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)

                output_label.append(self.vocab.stoi.get(token, self.vocab.unk_index))

            else:
                tokens[i] = self.vocab.stoi.get(token, self.vocab.unk_index)
                output_label.append(0)

        return tokens, output_label
    
    def __getitem__(self, item):
        t1, t2, is_next_label = self.random_sent(item)
        t1_random, t1_label = self.random_word(t1)
        t2_random, t2_label = self.random_word(t2)

        # [CLS] tag = SOS tag, [SEP] tag = EOS tag
        t1 = [self.vocab.sos_index] + t1_random + [self.vocab.eos_index]
        t2 = t2_random + [self.vocab.eos_index]

        t1_label = [self.vocab.pad_index] + t1_label + [self.vocab.pad_index]
        t2_label = t2_label + [self.vocab.pad_index]

        segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_label + t2_label)[:self.seq_len]

        padding = [self.vocab.pad_index for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_label.extend(padding)

        output = {"bert_input": bert_input, "bert_label": bert_label,
                  "segment_label": segment_label, "is_next": is_next_label}
        
        return {key: torch.tensor(value) for key, value in output.items()}
        
    def __len__(self):
        return self.corpus_lines


In [3]:
class Vocab(object):
    def __init__(self, counter, specials, max_size = None, min_freq = 1):
        self.freqs = counter
        self.itos = list(specials)

        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4

        for token in specials:
            del counter[token]

        words_and_freqs = sorted(counter.items(), key = lambda tup: tup[0])
        words_and_freqs.sort(key = lambda tup: tup[1], reverse=True)

        for word, freq in words_and_freqs:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        self.stoi = {token: i for i, token in enumerate(self.itos)}

    def __len__(self):
        return len(self.itos)

In [4]:
counter = Counter()
with open('./data/eng-fra.txt', "r", encoding="utf-8") as f:
    for line in tqdm.tqdm(f, desc="Loading Dataset"):
        if isinstance(line, list):
            words = line
        else:
            words = line.replace("\n", "").replace("\t", " ").split()

        for word in words:
            counter[word] += 1

Loading Dataset: 135842it [00:00, 263929.66it/s]


In [5]:
vocab = Vocab(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"])

In [8]:
vocab.stoi

{'<pad>': 0,
 '<unk>': 1,
 '<eos>': 2,
 '<sos>': 3,
 '<mask>': 4,
 'I': 5,
 'to': 6,
 'de': 7,
 'Je': 8,
 'a': 9,
 'you': 10,
 'the': 11,
 '?': 12,
 'pas': 13,
 'que': 14,
 'à': 15,
 'Tom': 16,
 'ne': 17,
 'la': 18,
 'le': 19,
 'Il': 20,
 'is': 21,
 'me': 22,
 'vous': 23,
 'est': 24,
 'of': 25,
 'un': 26,
 'He': 27,
 'in': 28,
 'ce': 29,
 'en': 30,
 'have': 31,
 'was': 32,
 'une': 33,
 'for': 34,
 'je': 35,
 'your': 36,
 'that': 37,
 'pour': 38,
 'suis': 39,
 'You': 40,
 "don't": 41,
 "I'm": 42,
 'les': 43,
 "J'ai": 44,
 '!': 45,
 'be': 46,
 'not': 47,
 'The': 48,
 'want': 49,
 'my': 50,
 'Elle': 51,
 'do': 52,
 'She': 53,
 'Nous': 54,
 'tu': 55,
 'Vous': 56,
 'Tu': 57,
 'this': 58,
 'like': 59,
 'on': 60,
 'it': 61,
 'with': 62,
 'are': 63,
 'dans': 64,
 'des': 65,
 'you.': 66,
 "C'est": 67,
 'nous': 68,
 'We': 69,
 'know': 70,
 'plus': 71,
 'te': 72,
 'faire': 73,
 'at': 74,
 'what': 75,
 'se': 76,
 'du': 77,
 'his': 78,
 'as': 79,
 'avec': 80,
 'veux': 81,
 'au': 82,
 'all': 83,
 'q

In [11]:
vocab.stoi.get('', 1)

6

In [42]:
train_dataset = BERTDataset(corpus_path = './data/eng-fra.txt', vocab = vocab , seq_len = 20)

Loading Dataset: 135842it [00:00, 273806.93it/s]


In [43]:
train_data_loader = DataLoader(train_dataset, batch_size = 64)

## Model

### Embedding

In [None]:
class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]
    
class SegmentEmbedding(nn.Embedding):
    def __init__(self, embed_size=512):
        super().__init__(3, embed_size, padding_idx=0)

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)

In [None]:
class BERTEmbedding(nn.Module):
    """
    BERT Embedding which is consisted with under features
        1. TokenEmbedding : normal embedding matrix
        2. PositionalEmbedding : adding positional information using sin, cos
        2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2)

        sum of all these features are output of BERTEmbedding
    """

    def __init__(self, vocab_size, embed_size, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.segment = SegmentEmbedding(embed_size=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        x = self.token(sequence) + self.position(sequence) + self.segment(segment_label)
        return self.dropout(x)

### Bert

In [None]:
class BERT(nn.Module):
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads

        # paper noted they used 4*hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = hidden * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=attn_heads, 
                                                   dim_feedforward = hidden * 4, dropout = dropout)
        
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, x, segment_info):
        # attention masking for padded token
        # torch.ByteTensor([batch_size, 1, seq_len, seq_len)
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x, segment_info)

        x = self.transformer_encoder(x, mask)
        return x

In [None]:
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))


class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))


class BERTLM(nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        return self.next_sentence(x), self.mask_lm(x)


## Train

In [44]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [45]:
def train(model, train_data_loader, criterion, optim_schedule):
    for i, data in enumerate(train_data_loader):
        data = {key:value for key, value in data.items()}
        next_sentence_out, mask_lm_out = model.forward(data["bert_input"], data["segment_label"])

        next_loss = criterion(next_sentence_out, data["is_next"])

        mask_loss = criterion(mask_lm_out.transpose(1, 2), data["bert_label"])

        loss = next_loss + mask_loss

        optim_schedule.zero_grad()
        loss.backward()
        optim_schedule.step_and_update_lr()


In [None]:
criterion = nn.NLLLoss(ignore_index = 0)
bert_hidden = 768
warmup_steps = 10000
lr = 1e-4
betas = (0.9, 0.999)
weight_decay = 0.01

bert = BERT(len(vocab))
model = BERTLM(bert, len(vocab))

optim = Adam(model.parameters(), lr = lr, betas = betas, weight_decay = weight_decay)
optim_schedule = ScheduledOptim(optim, bert.hidden, n_warmup_steps=warmup_steps)

In [None]:
train(model = model, 
      train_data_loader = train_data_loader, 
      criterion = criterion, 
      optim_schedule = optim_schedule)