**遗留问题**
* segment的padding怎么ignore
* wordpiece 怎么用
* torch.gather怎么用

In [66]:
import re
import math
from random import random, randrange, randint, shuffle

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

from utils import EncoderLayer
from utils import pad_mask

## 准备数据集

In [2]:
text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)

In [3]:
# clean data
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!'
word_list = list(set(" ".join(sentences).split())) # ['hello', 'how', 'are', 'you',...]
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {i: w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)

In [4]:
# sentence to idx
token_list = list()
for sentence in sentences:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)

In [5]:
token_list

[[32, 15, 25, 14, 22, 21, 37],
 [32, 37, 38, 31, 27, 24, 4, 19, 12, 14],
 [4, 12, 14, 18, 15, 25, 14, 16],
 [11, 38, 10, 17, 28, 5, 29],
 [6, 13, 24],
 [20, 14, 37],
 [7, 25, 14, 34, 16],
 [22, 21, 34, 35, 30, 33, 14],
 [22, 21, 34, 19, 36, 38, 39, 8, 27, 23, 26, 9]]

## 数据预处理
根据概率随机mask或者替换一句话中15%的token，还需要拼接任意两句话

In [6]:
batch_size = 6
max_len = 30
max_pred = 5

In [30]:
# padding
def zero_padding(input_ids, segment_ids, max_len, max_pred, n_pred, masked_pos, masked_tokens):
    n_pad = max_len - len(input_ids)
    input_ids.extend([0] * n_pad)
    segment_ids.extend([0] * n_pad)
    
    # zero padding tokens
    if max_pred > n_pred:
        n_pad = max_pred - n_pred
        masked_tokens.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)
    return input_ids, segment_ids, masked_pos, masked_tokens

In [8]:
# 随机MASK
def mask_lm(input_ids, max_pred):
    # 单句要预测的token个数(15%)
    n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
    # 候选mask id 列表， 特殊标记不可mask
    cand_maked_pos = [
        i for i, token in enumerate(input_ids)
        if token != word2idx['[CLS]'] and token != word2idx['[SEP]']
    ]
    shuffle(cand_maked_pos)
    masked_tokens, masked_pos = [], []
    for pos in cand_maked_pos[:n_pred]:
        masked_pos.append(pos)
        masked_tokens.append(input_ids[pos])
        if random() < 0.8:  # 80% 替换成[MASK]
            input_ids[pos] = word2idx['[MASK]']
        elif random() > 0.9:  # 10% 替换成任意词
            index = randint(0, vocab_size - 1)
            while index < 4:  # 特殊标记不可替换
                index = randint(0, vocab_size - 1)
            input_ids[pos] = index
    return masked_pos, masked_tokens, n_pred

In [9]:
# 抽取positive和negative样本, 比例为1:1
def batch_sampler(batch_size, token_list, max_pred, max_len):
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        # randrange(stop):返回一个随机数
        # 如果tokens_a_index + 1 = tokens_b_index, 则为positive，否则 negative
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        
        # mask
        masked_pos, masked_tokens, n_pred = mask_lm(input_ids, max_pred)
        # padding
        input_ids, segment_ids, masked_pos, masked_tokens = \
            zero_padding(input_ids, segment_ids, max_len, max_pred, n_pred, masked_pos, masked_tokens)
        
        if tokens_a_index + 1== tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_pos, masked_tokens, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_pos, masked_tokens, False])
            negative += 1
    return batch

In [17]:
class MyDataSet(Dataset):
    def __init__(self, input_ids, segment_ids, masked_pos, masked_tokens, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext
    
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], \
    self.masked_pos[idx], self.isNext[idx]

In [31]:
batch = batch_sampler(batch_size, token_list, max_pred, max_len)
input_ids, segment_ids, masked_pos, masked_tokens, isNext = zip(*batch)

In [32]:
input_ids, segment_ids, masked_pos, masked_tokens, isNext = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_pos),\
    torch.LongTensor(masked_tokens), torch.LongTensor(isNext)

In [34]:
loader = DataLoader(MyDataSet(input_ids, segment_ids, masked_pos, masked_tokens, isNext), batch_size, True)

## 构建BERT模型

In [72]:
vocab_size = vocab_size
batch_size = 6
embedding_size = 768
hidden_size = 768 * 4
dim_k = dim_v = 64
n_heads = 12
n_layers = 6
max_len = 30
n_segments = 2
lr = 1e-4

In [35]:
def gelu(x):
    """
      Implementation of the gelu activation function.
      For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
      Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

In [36]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embedding_size, max_length, n_segments):
        super(Embedding, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_size)
        self.postion_embedding = nn.Embedding(max_length, embedding_size)
        self.segment_embedding = nn.Embedding(n_segments, embedding_size)
        self.norm = nn.LayerNorm(embedding_size)
    
    def forward(self, X, seg):
        seq_len = X.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        pos = pos.unsqueeze(0).expand_as(X)
        embedding = self.token_embedding(X) + self.postion_embedding(pos) + self.segment_embedding(seg)
        embedding = self.norm(embedding)
        return embedding

In [63]:
class BERT(nn.Module):
    def __init__(self, vocab_size, embedding_size, k_dim, v_dim, n_heads, hidden_size, n_layers, \
                 max_length, n_segments, \
                 dropout=0.1):
        super(BERT, self).__init__()
        self.embedding = Embedding(vocab_size, embedding_size, max_length, n_segments)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderLayer(embedding_size, k_dim, v_dim, n_heads, hidden_size))
        # task1
        # shared weight with Token Embedding Layer
        self.linear = nn.Linear(embedding_size, embedding_size)
        self.fc_mlm=nn.Linear(embedding_size, vocab_size, bias=False)
        embedding_weight = self.embedding.token_embedding.weight
        self.fc_mlm.weight = embedding_weight
        self.activate = gelu
        # task2
        self.fc_task_nps = nn.Sequential(
            nn.Linear(embedding_size, embedding_size),
            nn.Dropout(dropout),
            nn.Tanh(),
            nn.Linear(embedding_size, 2)
        )
        
    def forward(self, input_ids, segment_ids, masked_pos):
        padding_mask = pad_mask(input_ids)
        X = self.embedding(input_ids, segment_ids)
        for layer in self.layers:
            X = layer(X, padding_mask)
            
        # task1: MLM
        masked_pos = masked_pos[:, :, None].expand(-1, -1, X.size(2))
        h_masked = torch.gather(X, 1, masked_pos)
        h_masked = self.activate(self.linear(h_masked))
        result_mlm = self.fc_mlm(h_masked)
        # task2: predict isNext by first token(CLS)
        reslult_isNext = self.fc_task_nps(X[:, 0])
        return result_mlm, reslult_isNext

In [74]:
loss = nn.CrossEntropyLoss()
model = BERT(vocab_size, embedding_size, dim_k, dim_v, n_heads, hidden_size, n_layers, max_len, n_segments)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [77]:
for epoch in range(180):
    output_mlm, output_isNext = model(input_ids, segment_ids, masked_pos)
    loss_mlm = loss(output_mlm.transpose(1, 2), masked_pos)
    loss_mlm = (loss_mlm.float()).mean()
    loss_isNext = loss(output_isNext, isNext)
    l = loss_mlm + loss_isNext
    optimizer.zero_grad()
    l.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(l))

Epoch: 0010 loss = 0.619410
Epoch: 0020 loss = 0.594902
Epoch: 0030 loss = 0.442325
Epoch: 0040 loss = 0.274707
Epoch: 0050 loss = 0.066766
Epoch: 0060 loss = 0.005501
Epoch: 0070 loss = 0.000794
Epoch: 0080 loss = 0.000278
Epoch: 0090 loss = 0.000183
Epoch: 0100 loss = 0.000146
Epoch: 0110 loss = 0.000122
Epoch: 0120 loss = 0.000094
Epoch: 0130 loss = 0.000082
Epoch: 0140 loss = 0.000075
Epoch: 0150 loss = 0.000049
Epoch: 0160 loss = 0.000037
Epoch: 0170 loss = 0.000026
Epoch: 0180 loss = 0.000018
