In [None]:
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

In [None]:
# Toy corpus: a short dialogue, one utterance per line (R = Romeo, J = Juliet).
text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)
# Lower-case the corpus, strip punctuation and split it into one sentence per line.
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
# Sort the unique words so that every run assigns the same id to the same word;
# iterating a bare set of strings yields a different order between sessions.
word_list = sorted(set(" ".join(sentences).split()))
# Ids 0-3 are reserved for the special tokens; corpus words start at 4.
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list, start=4):
    word2idx[w] = i
# Invert the mapping explicitly instead of relying on dict insertion order.
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(word2idx)

# Encode every sentence as a list of token ids.
token_list = [[word2idx[w] for w in sentence.split()] for sentence in sentences]

In [None]:
# Hyper-parameters shared by the cells below.
maxlen = 30                      # samples are zero-padded up to this many tokens
batch_size = 6                   # number of sentence pairs built by make_data
max_pred = 5                     # upper bound on predicted (masked) positions per sample
n_layers = 6                     # number of stacked encoder layers
n_heads = 12                     # attention heads per layer
d_model = 768                    # embedding / hidden width
d_ff = 4 * d_model               # inner width of the feed-forward block
d_k = d_v = d_model // n_heads   # per-head query/key and value width
n_segments = 2                   # size of the segment-embedding table

In [None]:
def make_data():
  """Build one batch of BERT pre-training samples from the toy corpus.

  Even-numbered samples are positive pairs (a sentence followed by the next
  sentence of the dialogue); odd-numbered samples are negative pairs drawn at
  random, re-drawn whenever the draw happens to be a genuine consecutive pair.

  For every sample roughly 15% of the non-special tokens (at least 1, at most
  ``max_pred``) are selected for prediction: 80% of them are replaced by
  ``[MASK]``, 10% by a random corpus word and 10% are left unchanged.

  Returns:
    A list with ``batch_size`` entries of the form
    ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``.
    ``input_ids`` and ``segment_ids`` are zero-padded up to ``maxlen``;
    ``masked_tokens`` and ``masked_pos`` are zero-padded up to ``max_pred``.
    A pair longer than ``maxlen`` is returned unpadded and untruncated.
  """
  cls_id, sep_id = word2idx['[CLS]'], word2idx['[SEP]']
  mask_id, pad_id = word2idx['[MASK]'], word2idx['[PAD]']
  n_sentences = len(sentences)

  batch = []
  for idx in range(batch_size):
    is_next = idx % 2 == 0
    if is_next:
      # Positive sample: two consecutive sentences.
      aid = randrange(n_sentences - 1)
      bid = aid + 1
    else:
      # Negative sample: a random pair. Re-draw while it is actually a
      # consecutive pair, otherwise the False label would be wrong.
      aid, bid = randrange(n_sentences), randrange(n_sentences)
      while bid == aid + 1:
        aid, bid = randrange(n_sentences), randrange(n_sentences)
    tokena, tokenb = token_list[aid], token_list[bid]
    input_ids = [cls_id] + tokena + [sep_id] + tokenb + [sep_id]
    segment_ids = [0] * (len(tokena) + 2) + [1] * (len(tokenb) + 1)

    # Pick the positions to predict; [CLS] and [SEP] are never candidates.
    n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
    masked_candis = [i for i, t in enumerate(input_ids) if t not in (cls_id, sep_id)]
    shuffle(masked_candis)
    masked_pos = masked_candis[:n_pred]
    masked_tokens = [input_ids[i] for i in masked_pos]
    for i in masked_pos:
      r = random()
      if r < 0.8:
        input_ids[i] = mask_id
      elif r > 0.9:
        # Replace with a random corpus word (ids 0-3 are special tokens).
        # randrange excludes its upper bound, so vocab_size is the correct
        # stop value to make the last word reachable.
        input_ids[i] = randrange(4, vocab_size)
      # otherwise the original token is kept

    # Zero-pad the sequence itself ...
    n1 = maxlen - len(input_ids)
    input_ids += [pad_id] * n1
    segment_ids += [0] * n1
    # ... and the prediction slots that were not used.
    n2 = max_pred - n_pred
    masked_tokens += [pad_id] * n2
    masked_pos += [0] * n2

    batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])

  return batch

def f(x):
  """Convert a (possibly nested) sequence of integers or booleans to a LongTensor."""
  as_long = torch.LongTensor(x)
  return as_long
# Build one batch, then transpose it from per-sample rows into per-field
# columns and convert every column to a LongTensor (is-next flags become 0/1).
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isnext = map(f, zip(*batch))

In [None]:
class MyDataSet(Data.Dataset):
  """Map-style dataset over five parallel, index-aligned sequences.

  Item ``i`` is the tuple
  ``(input_ids[i], segment_ids[i], masked_tokens[i], masked_pos[i], isnext[i])``.
  The dataset length is the length of ``input_ids``.
  """

  def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isnext):
    self.input_ids = input_ids
    self.segment_ids = segment_ids
    self.masked_tokens = masked_tokens
    self.masked_pos = masked_pos
    self.isnext = isnext

  def __len__(self):
    return len(self.input_ids)

  def __getitem__(self, idx):
    fields = (
      self.input_ids,
      self.segment_ids,
      self.masked_tokens,
      self.masked_pos,
      self.isnext,
    )
    return tuple(field[idx] for field in fields)

# Shuffled mini-batch iterator over the prepared tensors.
loader = Data.DataLoader(
  MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isnext),
  batch_size=batch_size,
  shuffle=True,
)

In [None]:
def get_attn_pad_mask(seq_q, seq_k):
  """Build a boolean attention mask that flags padding keys.

  Args:
    seq_q: token-id tensor for the queries; only its second dimension
      (the query length) is used.
    seq_k: token-id tensor for the keys; entries equal to 0 are treated as padding.

  Returns:
    A boolean tensor with a query axis inserted at dim 1 and expanded to the
    query length; an entry is True where the corresponding key id is 0.
  """
  n_queries = seq_q.shape[1]
  key_is_pad = seq_k.data.eq(0)
  per_key = key_is_pad.unsqueeze(1)
  return per_key.expand(-1, n_queries, -1)

In [None]:
# Sanity check: key positions holding id 0 are flagged for every query row.
tmp = torch.tensor([[0, 1, 0]])
get_attn_pad_mask(tmp, tmp)

tensor([[[ True, False,  True],
         [ True, False,  True],
         [ True, False,  True]]])

In [None]:
def gelu(x):
    """Exact (erf-based) GELU activation, applied element-wise to a tensor."""
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)

In [None]:
class Embedding(nn.Module):
  """Input embedding: token + learned position + segment embeddings, then LayerNorm.

  Table sizes come from the module-level ``vocab_size``, ``maxlen``,
  ``n_segments`` and ``d_model`` values.
  """

  def __init__(self):
    super().__init__()
    self.pos_embedding = nn.Embedding(maxlen, d_model)
    self.tok_embedding = nn.Embedding(vocab_size, d_model)
    self.seg_embedding = nn.Embedding(n_segments, d_model)
    self.norm = nn.LayerNorm(d_model)

  def forward(self, x, seg):
    """Embed token ids ``x`` and segment ids ``seg`` (both indexed as [batch, position]).

    Returns:
      The normalised sum of the three embeddings, with a trailing ``d_model`` axis.
    """
    seq_len = x.shape[1]
    # Create the position ids on the same device as the input; a CPU-only
    # arange would fail the embedding lookup once the model is moved elsewhere.
    pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
    pos = pos.unsqueeze(0).expand_as(x)
    emb = self.pos_embedding(pos) + self.tok_embedding(x) + self.seg_embedding(seg)
    return self.norm(emb)


In [None]:
# Smoke test: embed the prepared batch and inspect the output shape.
test = Embedding()
emb = test(input_ids, segment_ids)
emb.size()

torch.Size([6, 30, 768])

In [None]:
class ScaledDotProductAttention(nn.Module):
  """Scaled dot-product attention with a boolean mask.

  ``forward`` returns ``(context, attn)`` where ``attn`` is the softmax over
  the last axis of ``q @ k^T / sqrt(d_k)`` and ``context`` is ``attn @ v``.
  Entries where ``mask`` is True receive a score of -1e9 before the softmax.
  """

  def __init__(self):
    super().__init__()

  def forward(self, q, k, v, mask):
    keys_t = k.transpose(-1, -2)
    scaled = torch.matmul(q, keys_t) / np.sqrt(d_k)
    # In-place fill on the freshly created score tensor.
    scaled.masked_fill_(mask, -1e9)
    softmax = nn.Softmax(dim=-1)
    weights = softmax(scaled)
    return torch.matmul(weights, v), weights

In [None]:
# Smoke test: 2 sequences of length 4 whose last two token ids are 0 (padding),
# so the last two attention columns should come out as zero.
test = ScaledDotProductAttention()
tmp = torch.randn(2, n_heads, 4, 5)
sq = sk = torch.tensor([[1, 1, 0, 0], [2, 3, 0, 0]])

mask = get_attn_pad_mask(sq, sk)
# Insert a head axis and copy the mask once per head.
m = mask[:, None].repeat(1, n_heads, 1, 1)
c, a = test(tmp, tmp, tmp, m)

print(c.shape, a.shape)
print(a[0, 0])

torch.Size([2, 12, 4, 5]) torch.Size([2, 12, 4, 4])
tensor([[0.5690, 0.4310, 0.0000, 0.0000],
        [0.4828, 0.5172, 0.0000, 0.0000],
        [0.5154, 0.4846, 0.0000, 0.0000],
        [0.4784, 0.5216, 0.0000, 0.0000]])


In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head attention followed by a residual connection and LayerNorm.

  Head count and widths come from the module-level ``n_heads``, ``d_k``,
  ``d_v`` and ``d_model`` values.
  """

  def __init__(self):
    super().__init__()
    self.wq = nn.Linear(d_model, n_heads * d_k)
    self.wk = nn.Linear(d_model, n_heads * d_k)
    self.wv = nn.Linear(d_model, n_heads * d_v)
    self.dense = nn.Linear(n_heads * d_v, d_model)
    # Registered once here instead of being rebuilt inside forward(): only a
    # registered LayerNorm is trained by the optimizer, moved by .to(device)
    # and stored in the state_dict.
    self.norm = nn.LayerNorm(d_model)

  def forward(self, qx, kx, vx, mask):
    """Attend from ``qx`` to ``kx``/``vx``.

    Args:
      qx, kx, vx: inputs whose last axis is ``d_model``; ``qx`` is also the
        residual that is added to the attention output.
      mask: boolean mask without a head axis; True entries are masked out.

    Returns:
      ``LayerNorm(dense(attention) + qx)``.
    """
    # Read the batch size from the input; the global batch_size may not match
    # (e.g. a final, smaller batch).
    bs = qx.shape[0]
    # Split the projections into heads: the -1 axis is the sequence length,
    # and the transpose puts the head axis in front of it.
    q = self.wq(qx).view(bs, -1, n_heads, d_k).transpose(1, 2)
    k = self.wk(kx).view(bs, -1, n_heads, d_k).transpose(1, 2)
    v = self.wv(vx).view(bs, -1, n_heads, d_v).transpose(1, 2)
    # Give the mask a head axis and copy it once per head.
    mask = mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
    # ScaledDotProductAttention is a Module, so instantiate it before calling.
    context, _ = ScaledDotProductAttention()(q, k, v, mask)
    # Merge the heads back into one feature axis.
    context = context.transpose(1, 2).contiguous().view(bs, -1, n_heads * d_v)
    output = self.dense(context)
    return self.norm(output + qx)

In [None]:
# Smoke test: sq/sk have batch size 2 and sequence length 4.
# The tmp tensor used here has no n_heads axis; the module adds it internally.
tmp = torch.rand(2, 4, d_model)
sq = sk = torch.tensor([[1, 1, 0, 0], [2, 3, 0, 0]])
mask = get_attn_pad_mask(sq, sk)
test = MultiHeadAttention()
ot = test(tmp, tmp, tmp, mask)
print(ot.size())

torch.Size([2, 4, 768])


In [None]:
class FeedForward(nn.Module):
  """Two-layer feed-forward block: Linear(d_model, d_ff) -> gelu -> Linear(d_ff, d_model)."""

  def __init__(self):
    super().__init__()
    self.dense1 = nn.Linear(d_model, d_ff)
    self.dense2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    hidden = self.dense1(x)
    activated = gelu(hidden)
    return self.dense2(activated)

In [None]:
# Smoke test: the feed-forward block keeps the shape of the attention output.
test = FeedForward()
test(ot).size()

torch.Size([2, 4, 768])

In [None]:
class EncoderLayer(nn.Module):
  """One encoder layer: self-attention followed by the feed-forward block."""

  def __init__(self):
    super().__init__()
    self.attention = MultiHeadAttention()
    self.feedforward = FeedForward()

  def forward(self, inputs, mask):
    # Self-attention: the same tensor is used as query, key and value.
    attended = self.attention(inputs, inputs, inputs, mask)
    return self.feedforward(attended)

In [None]:
# Smoke test: one encoder layer on random input, reusing sq/sk from the
# multi-head attention cell to build the padding mask.
test = EncoderLayer()
tmp = torch.rand(2, 4, d_model)
mask = get_attn_pad_mask(sq, sk)
test(tmp, mask).size()

torch.Size([2, 4, 768])

In [None]:
# nn.Linear(in_features, out_features) stores its weight as (out_features, in_features).
ln0 = nn.Linear(3, 2)
print(ln0.weight.size())
ln = nn.Linear(2, 3)
# ln.weight = ln0.weight
ln(torch.ones(2))

torch.Size([2, 3])


tensor([-0.8619,  0.5594, -0.4929], grad_fn=<AddBackward0>)

In [None]:
class BERT(nn.Module):
  """Toy BERT: embedding, ``n_layers`` encoder layers and two output heads.

  ``forward`` returns ``(predclfs, matchclfs)``: vocabulary logits for every
  requested masked position and 2-way logits for the sentence-pair task.
  """

  def __init__(self):
    super().__init__()
    self.embedding = Embedding()
    self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
    # Head for the sentence-pair (is-next) task.
    self.matchds = nn.Sequential(
      nn.Linear(d_model, d_model),
      nn.Dropout(0.5),
      nn.Tanh(),
    )
    self.matchclf = nn.Linear(d_model, 2)
    # Head for the masked-token (cloze) task.
    self.predds = nn.Linear(d_model, d_model)
    self.predclf = nn.Linear(d_model, vocab_size, bias=False)
    # Tie the output projection to the token-embedding matrix.
    self.predclf.weight = self.embedding.tok_embedding.weight
    self.active = gelu

  def forward(self, input_ids, segment_ids, masked_pos):
    hidden = self.embedding(input_ids, segment_ids)
    pad_mask = get_attn_pad_mask(input_ids, input_ids)
    for encoder in self.layers:
      hidden = encoder(hidden, pad_mask)

    # Sentence-pair task: classify from the hidden state at position 0.
    pooled = self.matchds(hidden[:, 0])
    matchclfs = self.matchclf(pooled)

    # Cloze task: pick the hidden states at the masked positions by
    # broadcasting each position index across the feature axis.
    gather_index = masked_pos.unsqueeze(-1).expand(-1, -1, d_model)
    picked = torch.gather(hidden, 1, gather_index)
    transformed = self.active(self.predds(picked))
    predclfs = self.predclf(transformed)
    return predclfs, matchclfs

In [None]:
# Train the toy model on the single prepared batch.
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters(), lr=0.001)
for epoch in range(180):
  log_this_epoch = (epoch + 1) % 10 == 0
  for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
    # Masked-LM loss over every prediction slot. Unused slots carry target 0
    # ([PAD]) at position 0, so the model is also trained to emit [PAD] there;
    # the evaluation cell filters those zeros out when printing.
    # CrossEntropyLoss already returns the mean as a scalar.
    loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1))
    # Sentence-pair classification loss.
    loss_clsf = criterion(logits_clsf, isNext)
    loss = loss_lm + loss_clsf
    if log_this_epoch:
      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Epoch: 0010 loss = 0.868357
Epoch: 0020 loss = 0.786383
Epoch: 0030 loss = 0.733968
Epoch: 0040 loss = 0.728615
Epoch: 0050 loss = 0.744979
Epoch: 0060 loss = 0.697057
Epoch: 0070 loss = 0.673936
Epoch: 0080 loss = 0.664784
Epoch: 0090 loss = 0.633446
Epoch: 0100 loss = 0.646477
Epoch: 0110 loss = 0.600737
Epoch: 0120 loss = 0.524431
Epoch: 0130 loss = 0.460449
Epoch: 0140 loss = 0.512675
Epoch: 0150 loss = 0.446397
Epoch: 0160 loss = 0.476192
Epoch: 0170 loss = 0.421251
Epoch: 0180 loss = 0.287784


In [None]:
# Inspect the model's predictions on (up to) the first five samples of `batch`.
# Switch to eval mode first: the sentence-pair head contains Dropout(0.5), which
# would otherwise stay active and make the is-next prediction random.
# no_grad() skips building the autograd graph during inference.
model.eval()
with torch.no_grad():
  for idx, (input_ids, segment_ids, masked_tokens, masked_pos, isNext) in enumerate(batch[:5]):
    print('---------', idx)
    print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])

    # The model expects a batch axis, so wrap each field in a one-element list.
    logits_lm, logits_clsf = model(torch.LongTensor([input_ids]),
                                   torch.LongTensor([segment_ids]),
                                   torch.LongTensor([masked_pos]))
    # Most likely vocabulary id for every prediction slot of the single sample.
    pred_tokens = logits_lm.argmax(dim=2)[0].tolist()
    print('masked pos :', masked_pos)
    # Id 0 is [PAD]: it marks unused prediction slots and is not printed.
    print('masked tokens list : ',[idx2word[pos] for pos in masked_tokens if pos != 0])
    print('predict masked tokens list : ',[idx2word[pos] for pos in pred_tokens if pos != 0])

    pred_is_next = logits_clsf.argmax(dim=1)[0].item()
    print('isNext : ', True if isNext else False)
    print('predict isNext : ',True if pred_is_next else False)

--------- 0
['[CLS]', 'nice', 'meet', 'team', 'too', 'how', 'are', 'you', 'today', '[SEP]', 'great', 'my', 'baseball', 'team', 'won', 'the', 'competition', '[SEP]']
masked pos : [13, 3, 0, 0, 0]
masked tokens list :  ['team', 'you']
predict masked tokens list :  ['team', 'you']
isNext :  True
predict isNext :  True
--------- 1
['[CLS]', 'hello', '[MASK]', 'my', 'name', 'is', 'juliet', 'nice', 'to', 'meet', 'you', '[SEP]', 'hello', 'romeo', 'my', 'name', 'is', 'juliet', '[MASK]', 'to', 'meet', 'you', '[SEP]']
masked pos : [13, 2, 18, 0, 0]
masked tokens list :  ['romeo', 'romeo', 'nice']
predict masked tokens list :  ['romeo', 'romeo', 'nice']
isNext :  False
predict isNext :  False
--------- 2
['[CLS]', 'hello', 'romeo', 'my', 'name', '[MASK]', 'juliet', 'nice', 'to', 'meet', 'you', '[SEP]', 'nice', '[MASK]', 'you', 'too', 'how', 'are', 'you', 'today', '[SEP]']
masked pos : [1, 13, 5, 0, 0]
masked tokens list :  ['hello', 'meet', 'is']
predict masked tokens list :  ['hello', 'meet', 'i