In [1]:
import re
from random import randrange, shuffle, random, randint

import torch

## 准备数据集

In [2]:
 text = (
    'Hello, how are you? I am Romeo.\n' # R
    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
    'Nice meet you too. How are you today?\n' # R
    'Great. My baseball team won the competition.\n' # J
    'Oh Congratulations, Juliet\n' # R
    'Thank you Romeo\n' # J
    'Where are you going today?\n' # R
    'I am going shopping. What about you?\n' # J
    'I am going to visit my grandmother. she is not very well' # R
)

In [3]:
def clean_text(text):
    # clean data
    sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filter '.', ',', '?', '!'
    word_list = list(set(" ".join(sentences).split())) # ['hello', 'how', 'are', 'you',...]
    word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 4
    idx2word = {i: w for i, w in enumerate(word2idx)}
    vocab_size = len(word2idx)
    return word2idx, idx2word, vocab_size, word_list, sentences

In [4]:
def text2token(sentences):
    token_list = list()
    for sentence in sentences:
        arr = [word2idx[s] for s in sentence.split()]
        token_list.append(arr)
    return token_list

In [5]:
word2idx, idx2word, vocab_size, word_list, sentences = clean_text(text)

In [6]:
token_list = text2token(sentences)

## BERT预处理数据

In [7]:
batch_size = 6
max_len = 30
max_pred = 5

In [17]:
# padding
def zero_padding(input_ids, segment_ids, max_len, max_pred, n_pred, masked_pos, masked_tokens):
    n_pad = max_len - len(input_ids)
    input_ids.extend([0] * n_pad)
    segment_ids.extend([0] * n_pad)
    
    # zero padding tokens
    if max_pred > n_pred:
        masked_tokens.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)
    return input_ids, segment_ids, masked_pos, masked_tokens

In [12]:
# 随机MASK
def mask_lm(input_ids, max_pred):
    # 单句要预测的token个数(15%)
    n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
    # 候选mask id 列表， 特殊标记不可mask
    cand_maked_pos = [
        i for i, token in enumerate(input_ids)
        if token != word2idx['[CLS]'] and token != word2idx['[SEP]']
    ]
    shuffle(cand_maked_pos)
    masked_tokens, masked_pos = [], []
    for pos in cand_maked_pos[:n_pred]:
        masked_pos.append(pos)
        masked_tokens.append(input_ids[pos])
        if random() < 0.8:  # 80% 替换成[MASK]
            input_ids[pos] = word2idx['[MASK]']
        elif random() > 0.9:  # 10% 替换成任意词
            index = randint(0, vocab_size - 1)
            while index < 4:  # 特殊标记不可替换
                index = randint(0, vocab_size - 1)
            input_ids[pos] = index
    return masked_pos, masked_tokens, n_pred

In [20]:
# 抽取positive和negative样本, 比例为1:1
def batch_sampler(batch_size, token_list, max_pred, max_len):
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        # randrange(stop):返回一个随机数
        # 如果tokens_a_index + 1 = tokens_b_index, 则为positive，否则 negative
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        
        # mask
        masked_pos, masked_tokens, n_pred = mask_lm(input_ids, max_pred)
        # padding
        input_ids, segment_ids, masked_pos, masked_tokens = \
            zero_padding(input_ids, segment_ids, max_len, max_pred, n_pred, masked_pos, masked_tokens)
        
        if tokens_a_index + 1== tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_pos, masked_tokens, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size/2:
            batch.append([input_ids, segment_ids, masked_pos, masked_tokens, False])
            negative += 1
    return batch

In [21]:
batch = batch_sampler(batch_size, token_list, max_pred, max_len)

In [22]:
input_ids = batch[0][0]

In [23]:
input_ids

[1,
 38,
 39,
 12,
 17,
 34,
 8,
 3,
 2,
 19,
 18,
 9,
 2,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [None]:
# 1. 抽取positive和negative样本, 比例为1:1
# 2. 随机MASK
# 3. padding
def make_data():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        cand_maked_pos = [i for i, token in enumerate(input_ids) 
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif random() > 0.9:
                index = randint(0, vocab_size - 1)
                while index < 4:
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index
        # zero padding
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        
        # zero padding (100%-15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
            
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1
    return batch

In [None]:
# 1. 抽取positive和negative样本, 比例为1:1
# 2. 随机MASK
# 3. padding
def make_data():
    batch = []
    positive = negative = 0
    while positive != batch_size/2 or negative != batch_size/2:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))
        cand_maked_pos = [i for i, token in enumerate(input_ids) 
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]
        
        shuffle(cand_maked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:
                input_ids[pos] = word2idx['[MASK]']
            elif random() > 0.9:
                index = randint(0, vocab_size - 1)
                while index < 4:
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index
        # zero padding
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        
        # zero padding (100%-15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
            
        if tokens_a_index + 1 == tokens_b_index and positive < batch_size/2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < batch_size / 2:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])
            negative += 1
    return batch

In [None]:
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)

In [None]:
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids),  torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens),\
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)

In [None]:
input_ids, input_ids.shape

In [None]:
segment_ids, segment_ids.shape

In [None]:
masked_pos, masked_pos.shape