In [1]:
import sys

In [2]:
sys.path.insert(1, '../utils/')

In [3]:
import re
from model import *
from config import *

In [4]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [5]:
text = (
       'Hello, how are you? I am Romeo.\n'
       'Hello, Romeo My name is Juliet. Nice to meet you.\n'
       'Nice meet you too. How are you today?\n'
       'Great. My baseball team won the competition.\n'
       'Oh Congratulations, Juliet\n'
       'Thanks you Romeo'
   )

In [6]:
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')  # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split()))

In [7]:
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list):
    
    word_dict[w] = i + 4
number_dict = {i: w for i, w in enumerate(word_dict)}
vocab_size = len(word_dict)
    
    
token_list = list()
for sentence in sentences:
    arr = [word_dict[s] for s in sentence.split()]
    token_list.append(arr)

In [8]:
def make_batch(config):
    batch_size = config['batch_size']
    maxlen = config['maxlen']
    max_pred = config['max_pred']
    
    
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index= randrange(len(sentences)), randrange(len(sentences)) 

        tokens_a, tokens_b= token_list[tokens_a_index], token_list[tokens_b_index]
 
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]


        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # MASK LM
        n_pred =  min(max_pred, max(1, int(round(len(input_ids) * 0.15)))) # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                         if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word_dict['[MASK]'] # make mask
            elif random() < 0.5:  # 10%
                index = randint(0, vocab_size - 1) # random index in vocabulary
                input_ids[pos] = word_dict[number_dict[index]] # replace

       # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True]) # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False]) # NotNext
            negative += 1
    return batch

In [9]:
config = {
    'wordvocab_size':vocab_size,
    'maxlen': 64,
    'batch_size': 32,
    'max_pred': 5, 
    'n_layers': 12,
    'n_heads': 12,
    'd_model': 768,
    'd_ff': 768 * 4,
    'd_k': 64,
    'n_segments': 2
}

In [10]:
config['maxlen']

64

In [11]:
config

{'wordvocab_size': 29,
 'maxlen': 64,
 'batch_size': 32,
 'max_pred': 5,
 'n_layers': 12,
 'n_heads': 12,
 'd_model': 768,
 'd_ff': 3072,
 'd_k': 64,
 'n_segments': 2}

In [12]:
model = BERT(config = config)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
batch = make_batch(config=config)
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

In [13]:
input_ids

tensor([[ 1,  3, 15,  ...,  0,  0,  0],
        [ 1, 20, 14,  ...,  0,  0,  0],
        [ 1, 20, 14,  ...,  0,  0,  0],
        ...,
        [ 1, 16,  8,  ...,  0,  0,  0],
        [ 1,  3, 22,  ...,  0,  0,  0],
        [ 1, 16,  8,  ...,  0,  0,  0]])

In [16]:
masked_pos

tensor([[14,  1,  4,  0,  0],
        [ 8,  7,  0,  0,  0],
        [ 3,  8,  0,  0,  0],
        [ 3,  0,  0,  0,  0],
        [11, 14, 16,  0,  0],
        [10,  8,  0,  0,  0],
        [10,  7, 11,  0,  0],
        [15, 18,  4,  0,  0],
        [18, 10,  7,  0,  0],
        [ 6, 10,  0,  0,  0],
        [ 1,  3,  0,  0,  0],
        [12,  5,  6,  0,  0],
        [ 7,  5,  0,  0,  0],
        [ 3, 15,  7,  0,  0],
        [ 7,  0,  0,  0,  0],
        [10,  9, 12,  0,  0],
        [ 6,  2,  9,  0,  0],
        [16,  2,  7,  0,  0],
        [16,  5,  3,  0,  0],
        [ 2, 12,  6,  0,  0],
        [ 6,  0,  0,  0,  0],
        [10,  3,  0,  0,  0],
        [11,  4,  0,  0,  0],
        [ 1, 10, 14,  0,  0],
        [ 2,  0,  0,  0,  0],
        [ 7,  6, 14,  0,  0],
        [ 7,  3,  5,  0,  0],
        [ 6,  2,  0,  0,  0],
        [ 7,  0,  0,  0,  0],
        [13, 17, 18,  0,  0],
        [14, 15,  1,  0,  0],
        [ 3,  7,  4,  0,  0]])

In [47]:
for epoch in range(100):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    loss_lm = criterion(logits_lm.transpose(1, 2), masked_tokens) # for masked LM
    loss_lm = (loss_lm.float()).mean()
    loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
    loss = loss_lm + loss_clsf
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Epoch: 0010 cost = 2.512177
Epoch: 0020 cost = 0.176875
Epoch: 0030 cost = 0.459661
Epoch: 0040 cost = 1.172263
Epoch: 0050 cost = 228.027298
Epoch: 0060 cost = 1.047376
Epoch: 0070 cost = 0.179516
Epoch: 0080 cost = 11.528533
Epoch: 0090 cost = 4.905905
Epoch: 0100 cost = 2.647542


In [48]:
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))
print(text)
print()
print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])

logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
print('masked tokens list : ',[pos.item() for pos in masked_tokens[0] if pos.item() != 0])
print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])

logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
print('isNext : ', True if isNext else False)
print('predict isNext : ',True if logits_clsf else False)

Hello, how are you? I am Romeo.
Hello, Romeo My name is Juliet. Nice to meet you.
Nice meet you too. How are you today?
Great. My baseball team won the competition.
Oh Congratulations, Juliet
Thanks you Romeo

['[CLS]', '[MASK]', 'you', 'romeo', '[SEP]', 'hello', 'how', 'are', 'you', 'i', 'baseball', 'romeo', '[SEP]']
masked tokens list :  [22, 12]
predict masked tokens list :  [22, 22]
isNext :  False
predict isNext :  False
