In [2]:
from bert import tokenization

In [3]:
# Tokenizer built from the 'MASS.wordpiece' vocabulary file, with
# do_lower_case disabled.
tokenizer = tokenization.FullTokenizer(
    vocab_file='MASS.wordpiece',
    do_lower_case=False,
)




In [4]:
import random

# Module-level random generator with a fixed seed; the masking code below
# draws all of its random decisions (shuffles, mask/keep/replace) from it.
rng = random.Random(12345)

In [5]:
# Maximum number of tokens kept per sequence; also the padded length of every
# feature written by to_tfrecord (through maxlen).
max_seq_length = 300
# NOTE: reassigned to 15 in a later cell, before get_inputs reads it.
dupe_factor = 5
# Probability that process_NoisyLanguagePairDataset draws a random short length.
short_seq_prob = 0.1
# Fraction of a sequence's tokens used to compute how many tokens to mask.
masked_lm_prob = 0.2
# Cap applied when computing how many tokens to mask per sequence.
max_predictions_per_seq = 20
# Id appended to the decoder input and to the label in to_tfrecord.
eos_id = 1

# Every vocabulary entry; random replacement tokens are sampled from this list.
vocab_words = list(tokenizer.vocab.keys())

In [6]:
import collections
import tensorflow as tf
from tqdm import tqdm

# Length that to_tfrecord right-pads every id sequence to.
maxlen = max_seq_length

def create_int_feature(values):
    """Wrap an iterable of integers in a ``tf.train.Feature`` (int64 list)."""
    int64_values = tf.train.Int64List(value=list(values))
    return tf.train.Feature(int64_list=int64_values)

def to_tfrecord(rows, filename):
    """Serialize token triples into ``{filename}.tfrecord``.

    Args:
        rows: sequence of ``(input_encoder, input_decoder, label)`` triples;
            each element is a list of tokens that
            ``tokenizer.convert_tokens_to_ids`` can convert.
        filename: output path without the ``.tfrecord`` suffix.

    Each sequence is converted to ids; ``eos_id`` is appended to the decoder
    input and to the label; all three are right-padded with 0 up to
    ``maxlen``. One ``tf.train.Example`` holding the int64 features
    ``input_encoder``, ``input_decoder`` and ``y`` is written per row.
    """

    def pad(ids):
        # Right-pad with 0. A sequence already at or beyond maxlen is returned
        # unchanged because the repeat count is then <= 0.
        return ids + [0] * (maxlen - len(ids))

    input_encoders, input_decoders, labels = [], [], []

    for i in tqdm(range(len(rows))):
        # Fix: index with the loop variable. The previous code read rows[0] on
        # every iteration, so the file held len(rows) copies of the first row.
        input_encoder, input_decoder, label = rows[i]
        input_encoder = tokenizer.convert_tokens_to_ids(input_encoder)
        input_decoder = tokenizer.convert_tokens_to_ids(input_decoder) + [eos_id]
        label = tokenizer.convert_tokens_to_ids(label) + [eos_id]
        input_encoders.append(pad(input_encoder))
        input_decoders.append(pad(input_decoder))
        labels.append(pad(label))

    writer = tf.python_io.TFRecordWriter(f'{filename}.tfrecord')
    try:
        for i in tqdm(range(len(labels))):
            features = collections.OrderedDict()
            features['input_encoder'] = create_int_feature(input_encoders[i])
            features['input_decoder'] = create_int_feature(input_decoders[i])
            features['y'] = create_int_feature(labels[i])
            tf_example = tf.train.Example(
                features = tf.train.Features(feature = features)
            )
            writer.write(tf_example.SerializeToString())
    finally:
        # Always release the file handle, even if a write fails midway.
        writer.close()
    
def process_NoisyLanguagePairDataset(row):
    """Apply whole-word masking noise to both sides of a tokenized pair.

    Args:
        row: ``(tokens_l, tokens_r)``, two lists of word-piece tokens.

    Returns:
        A 3-tuple ``(noised_l, noised_r, tokens_r)``: the left and right
        sequences with some whole words replaced, and the un-noised right
        sequence. All three are truncated to ``max_seq_length`` tokens.

    Random decisions are drawn from the module-level ``rng`` in a fixed order
    (short-length draw, shuffle left, shuffle right, mask left, mask right).
    """
    max_num_tokens = max_seq_length
    # NOTE(review): the short target length drawn here has never been applied
    # to the tokens, so short_seq_prob has no effect on the output. The draws
    # are kept only so the rng stream, and therefore the generated data, stays
    # identical. Confirm whether truncation was intended.
    if rng.random() < short_seq_prob:
        rng.randint(2, max_num_tokens)
    tokens_l = row[0][:max_num_tokens]
    tokens_r = row[1][:max_num_tokens]

    cand_indexes_l = _whole_word_candidates(tokens_l)
    cand_indexes_r = _whole_word_candidates(tokens_r)

    # Both shuffles happen before any masking draw; keep this order.
    rng.shuffle(cand_indexes_l)
    rng.shuffle(cand_indexes_r)

    output_tokens_l = _mask_whole_words(tokens_l, cand_indexes_l)
    output_tokens_r = _mask_whole_words(tokens_r, cand_indexes_r)

    return output_tokens_l, output_tokens_r, tokens_r


def _whole_word_candidates(tokens):
    """Group token positions into words; a '##' piece joins the previous group."""
    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == '[CLS]' or token == '[SEP]':
            continue
        if cand_indexes and token.startswith('##'):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
    return cand_indexes


def _mask_whole_words(tokens, cand_indexes):
    """Return a copy of ``tokens`` with words from ``cand_indexes`` noised."""
    num_to_predict = min(
        max_predictions_per_seq,
        max(1, int(round(len(tokens) * masked_lm_prob))),
    )

    output_tokens = list(tokens)
    num_masked = 0
    covered_indexes = set()
    for index_set in cand_indexes:
        # The budget is checked once per word, so the last word selected may
        # push the number of masked tokens past num_to_predict.
        if num_masked >= num_to_predict:
            break
        if any(index in covered_indexes for index in index_set):
            continue
        for index in index_set:
            covered_indexes.add(index)

            # 80% of the time, replace with [MASK]
            if rng.random() < 0.8:
                masked_token = '[MASK]'
            # 10% of the time, keep original
            elif rng.random() < 0.5:
                masked_token = tokens[index]
            # 10% of the time, replace with random word
            else:
                masked_token = vocab_words[
                    rng.randint(0, len(vocab_words) - 1)
                ]

            output_tokens[index] = masked_token
            num_masked += 1

    return output_tokens

In [8]:
# Overrides the dupe_factor = 5 set in the configuration cell; get_inputs
# generates this many noised samples per sentence pair.
dupe_factor = 15

def get_inputs(left, right, filename):
    """Tokenize parallel sentences, noise them, and write a TFRecord file.

    Args:
        left: sequence of source-side strings.
        right: sequence of strings aligned index-by-index with ``left``.
        filename: output name passed on to ``to_tfrecord``.

    Pairs where either side is empty after stripping, or where either side
    tokenizes to ``max_seq_length`` tokens or more, are dropped. Every
    remaining pair is noised ``dupe_factor`` times with
    ``process_NoisyLanguagePairDataset``.
    """
    import logging

    all_documents = []
    for i in tqdm(range(len(left))):
        line_l = tokenization.convert_to_unicode(left[i])
        line_r = tokenization.convert_to_unicode(right[i])

        line_l = line_l.strip()
        line_r = line_r.strip()
        if line_l and line_r:
            tokens_l = tokenizer.tokenize(line_l)
            tokens_r = tokenizer.tokenize(line_r)
            if len(tokens_l) < max_seq_length and len(tokens_r) < max_seq_length:
                all_documents.append((tokens_l, tokens_r))

    results = []
    failures = 0

    for row in tqdm(all_documents):
        for _ in range(dupe_factor):
            try:
                results.append(process_NoisyLanguagePairDataset(row))
            except Exception:
                # Still best-effort (a bad sample is skipped), but no longer a
                # bare `except`, so KeyboardInterrupt/SystemExit propagate and
                # the number of dropped samples is reported below.
                failures += 1

    if failures:
        logging.getLogger(__name__).warning(
            '%d of %d samples failed during noising and were skipped (%s)',
            failures,
            len(all_documents) * dupe_factor,
            filename,
        )

    to_tfrecord(results, filename)

In [9]:
def chunks_multiple(l, n, filename):
    """Yield ``(xs, ys, name)`` for consecutive chunks of a list of pairs.

    Args:
        l: list of 2-tuples.
        n: chunk size. Values below 1 are treated as 1; callers compute this
            as ``len(l) // k``, which is 0 for short inputs and previously
            made ``range()`` raise ValueError.
        filename: prefix for the chunk name.

    Yields:
        ``(xs, ys, f'{filename}-part-{k}')`` where ``xs`` and ``ys`` are tuples
        holding the first and second elements of the chunk's pairs and ``k``
        counts chunks from 0. An empty ``l`` yields nothing.
    """
    step = max(1, n)
    for part, start in enumerate(range(0, len(l), step)):
        x, y = zip(*l[start : start + step])
        yield (x, y, f'{filename}-part-{part}')

In [10]:
# Input JSON files processed by the loop below, which reads parallel 'left'
# and 'right' lists from each (presumably one file per translation direction).
pairs = ['en-ms.json', 'ms-en.json']

In [11]:
import json
import multi

# Number of sentence pairs handled per outer batch.
batch_size = 500000
# Each batch is cut into chunks of len(batch) // chunk_divisor pairs, giving
# chunk_divisor chunks plus one extra chunk when there is a remainder.
chunk_divisor = 12

for pair in pairs:
    with open(pair) as fopen:
        data = json.load(fopen)

    part = 0
    for i in range(0, len(data['left']), batch_size):
        index = min(i + batch_size, len(data['left']))
        l = data['left'][i: index]
        r = data['right'][i: index]
        filename = f'{pair}-part-{part}'

        # A final batch with fewer than chunk_divisor pairs would give a chunk
        # size of 0, which makes range() inside chunks_multiple raise
        # ValueError; clamp to at least one pair per chunk.
        chunk_size = max(1, len(l) // chunk_divisor)
        # Each chunk is handed to get_inputs through multi.multiprocessing,
        # which writes one '<filename>-part-<k>.tfrecord' per chunk.
        multi.multiprocessing(chunks_multiple(list(zip(l, r)), chunk_size, filename), get_inputs)
        part += 1

100%|██████████| 41666/41666 [00:48<00:00, 855.31it/s] 
100%|██████████| 41666/41666 [00:49<00:00, 842.32it/s]
100%|██████████| 41666/41666 [00:49<00:00, 834.05it/s]
100%|██████████| 41666/41666 [00:50<00:00, 830.35it/s]
100%|██████████| 41666/41666 [00:50<00:00, 826.49it/s]
100%|██████████| 41666/41666 [00:49<00:00, 838.19it/s]
100%|██████████| 41666/41666 [00:50<00:00, 830.95it/s]
100%|██████████| 41666/41666 [00:49<00:00, 835.99it/s]
100%|██████████| 41666/41666 [00:49<00:00, 833.86it/s]
100%|██████████| 41666/41666 [00:50<00:00, 827.85it/s]
100%|██████████| 41666/41666 [00:50<00:00, 821.55it/s]
100%|██████████| 41666/41666 [00:50<00:00, 821.19it/s]
100%|██████████| 41645/41645 [01:10<00:00, 591.55it/s]
100%|██████████| 41609/41609 [01:11<00:00, 577.96it/s]t/s]
100%|██████████| 41630/41630 [01:12<00:00, 574.71it/s]t/s]
 98%|█████████▊| 40836/41630 [01:11<00:01, 640.10it/s]s]s]
100%|██████████| 41612/41612 [01:12<00:00, 576.90it/s]s]]]
100%|██████████| 41627/41627 [01:12<00:00, 576.7

100%|██████████| 8/8 [00:00<00:00, 1799.55it/s]53.41it/s]
100%|██████████| 8/8 [00:00<00:00, 1289.51it/s]01.16it/s]
100%|██████████| 120/120 [00:00<00:00, 26834.96it/s]
100%|██████████| 120/120 [00:00<00:00, 3043.38it/s]3it/s]
100%|██████████| 624450/624450 [03:12<00:00, 3245.12it/s]
100%|██████████| 624255/624255 [03:13<00:00, 3230.35it/s]
100%|██████████| 624345/624345 [03:13<00:00, 3219.56it/s]
100%|██████████| 624360/624360 [03:12<00:00, 3246.31it/s]
100%|██████████| 624315/624315 [03:13<00:00, 3233.45it/s]
100%|██████████| 624375/624375 [03:14<00:00, 3215.74it/s]
100%|██████████| 624465/624465 [03:13<00:00, 3232.86it/s]
100%|██████████| 624345/624345 [03:13<00:00, 3221.42it/s]
100%|██████████| 624450/624450 [03:14<00:00, 3215.03it/s]
100%|██████████| 624300/624300 [03:13<00:00, 3231.92it/s]
100%|██████████| 41666/41666 [00:49<00:00, 843.25it/s]
100%|██████████| 41666/41666 [00:49<00:00, 838.67it/s]
100%|██████████| 41666/41666 [00:49<00:00, 843.79it/s]
 99%|█████████▊| 41059/41666

100%|█████████▉| 622738/624495 [00:21<00:00, 42804.92it/s]
100%|██████████| 624495/624495 [00:21<00:00, 29175.10it/s]
100%|██████████| 624300/624300 [00:21<00:00, 28691.09it/s]
100%|██████████| 624420/624420 [00:22<00:00, 27918.82it/s]
100%|██████████| 624330/624330 [00:22<00:00, 27242.48it/s]
100%|██████████| 624585/624585 [00:25<00:00, 24641.11it/s]
100%|██████████| 624480/624480 [00:24<00:00, 25201.68it/s]
100%|██████████| 624495/624495 [00:26<00:00, 23452.06it/s]
100%|██████████| 624375/624375 [00:25<00:00, 24177.25it/s]
100%|██████████| 624360/624360 [00:26<00:00, 23386.71it/s]
100%|██████████| 624315/624315 [01:11<00:00, 8763.11it/s] 
100%|██████████| 624300/624300 [03:00<00:00, 3467.32it/s]
100%|██████████| 624495/624495 [03:01<00:00, 3440.54it/s]
100%|██████████| 624360/624360 [03:01<00:00, 3430.97it/s]
100%|██████████| 624435/624435 [03:02<00:00, 3418.03it/s]
100%|██████████| 624420/624420 [03:01<00:00, 3436.52it/s]
100%|██████████| 8/8 [00:00<00:00, 387.52it/s]790.14it/s]
100

100%|██████████| 41666/41666 [00:30<00:00, 1376.87it/s]/s]
100%|██████████| 41666/41666 [00:30<00:00, 1365.54it/s]/s]
100%|██████████| 41666/41666 [00:32<00:00, 1294.36it/s]/s]
100%|██████████| 41666/41666 [00:33<00:00, 1260.37it/s]/s]
100%|██████████| 624450/624450 [00:19<00:00, 31943.58it/s]
100%|██████████| 624300/624300 [00:17<00:00, 35943.21it/s]
100%|██████████| 41560/41560 [00:30<00:00, 1376.46it/s]/s]
 39%|███▊      | 16043/41665 [00:18<00:23, 1109.23it/s]/s]
100%|██████████| 41609/41609 [00:33<00:00, 1238.93it/s]]s]
100%|██████████| 624135/624135 [00:18<00:00, 33381.63it/s]
100%|██████████| 41666/41666 [00:43<00:00, 957.21it/s] /s]
100%|██████████| 623400/623400 [00:24<00:00, 25302.53it/s]
100%|██████████| 41665/41665 [00:44<00:00, 939.04it/s] s]
100%|██████████| 41666/41666 [00:45<00:00, 914.62it/s]s]]]
 24%|██▎       | 147364/624990 [00:05<00:12, 39290.81it/s]
100%|██████████| 41666/41666 [00:48<00:00, 862.44it/s]/s]]
100%|██████████| 41665/41665 [00:49<00:00, 836.16it/s]/s]

100%|██████████| 624975/624975 [02:55<00:00, 3557.60it/s]
100%|██████████| 624990/624990 [02:56<00:00, 3547.95it/s]
100%|██████████| 624990/624990 [02:55<00:00, 3559.18it/s]
100%|██████████| 41666/41666 [00:24<00:00, 1717.42it/s]
100%|██████████| 41666/41666 [00:25<00:00, 1620.58it/s]
100%|██████████| 41666/41666 [00:26<00:00, 1577.43it/s]
100%|██████████| 41666/41666 [00:26<00:00, 1562.46it/s]
100%|██████████| 41666/41666 [00:27<00:00, 1491.37it/s]
100%|██████████| 41666/41666 [00:28<00:00, 1472.63it/s]
100%|██████████| 41666/41666 [00:28<00:00, 1438.59it/s]
100%|██████████| 41666/41666 [00:30<00:00, 1371.32it/s]
100%|██████████| 41666/41666 [00:31<00:00, 1310.93it/s]
100%|██████████| 41666/41666 [00:32<00:00, 1300.29it/s]
100%|██████████| 41666/41666 [00:32<00:00, 1296.65it/s]
100%|██████████| 41666/41666 [00:32<00:00, 1292.23it/s]
100%|██████████| 41665/41665 [00:39<00:00, 1053.72it/s]
100%|██████████| 41663/41663 [00:41<00:00, 1015.20it/s]s]
100%|██████████| 41664/41664 [00:42<00:0

100%|██████████| 624990/624990 [00:30<00:00, 20633.17it/s]
100%|██████████| 624975/624975 [02:57<00:00, 3522.00it/s]
100%|██████████| 8/8 [00:00<00:00, 1330.10it/s]01.99it/s]
100%|██████████| 8/8 [00:00<00:00, 990.36it/s]204.70it/s]
100%|██████████| 120/120 [00:00<00:00, 51500.71it/s]
100%|██████████| 120/120 [00:00<00:00, 4361.76it/s]7it/s]
100%|██████████| 624990/624990 [02:56<00:00, 3531.12it/s]
100%|██████████| 624990/624990 [02:56<00:00, 3531.81it/s]
100%|██████████| 624990/624990 [02:56<00:00, 3532.33it/s]
100%|██████████| 624945/624945 [02:58<00:00, 3508.97it/s]
100%|██████████| 624990/624990 [02:57<00:00, 3524.92it/s]
100%|██████████| 624990/624990 [02:57<00:00, 3520.62it/s]
100%|██████████| 624930/624930 [02:57<00:00, 3515.51it/s]
100%|██████████| 624990/624990 [02:56<00:00, 3534.22it/s]
100%|██████████| 624990/624990 [02:57<00:00, 3519.46it/s]
100%|██████████| 624960/624960 [02:56<00:00, 3533.78it/s]
100%|██████████| 624990/624990 [02:57<00:00, 3527.00it/s]
100%|██████████| 4