In [3]:
import torch
from torch.nn import Embedding

import random
import torch
from transformers import BertTokenizer



class SimpleTokenizer:
    """Builds BERT-style pre-training features (masked-LM + next-sentence) for a sentence pair.

    Only the vocabulary lookup of the wrapped ``bert-base-uncased`` tokenizer is
    used: sentences are split on whitespace and each word is looked up whole via
    ``convert_tokens_to_ids`` (no WordPiece sub-word splitting). Words missing
    from the vocabulary fall back to the tokenizer's own unknown-token handling
    (presumably the ``[UNK]`` id -- confirm against the transformers docs).
    """

    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # end

    @staticmethod
    def _truncate_pair(tokens_a, tokens_b, budget):
        """Drop trailing words in place, always from the longer sentence, until the pair fits ``budget``."""
        while len(tokens_a) + len(tokens_b) > budget:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
            # end
        # end
    # end

    def generate_training_embedding(self, seq_a, seq_b, probs_mask=0.15, max_length=64, is_next=True, rng=None):
        """Encode a sentence pair as ``[CLS] a [SEP] b [SEP] [PAD]...`` of exactly ``max_length`` tokens.

        If the pair plus its 3 special tokens is longer than ``max_length``, the
        sentences are truncated longest-first from their ends. Every example then
        has the same length and can be batched.

        Args:
            seq_a: first sentence (whitespace-separated words).
            seq_b: second sentence.
            probs_mask: fraction of word positions (never [CLS]/[SEP]/[PAD]) chosen
                for masking; the count is ``int(n_words * probs_mask)``, rounded down.
            max_length: length of every returned tensor except ``is_next``; must be >= 3.
            is_next: next-sentence label stored in ``is_next``.
            rng: optional ``random.Random`` used to pick the masked positions; the
                module-level ``random`` generator is used when ``None``.

        Returns:
            dict of 1-D tensors of length ``max_length`` (``is_next`` has length 1):
              ``tokens_id``  (int32) vocabulary ids; masked positions keep their
                             ORIGINAL id, they are not replaced by ``[MASK]``;
              ``masks``      (bool)  True at the masked positions;
              ``segments``   (int32) 0 for ``[CLS] a [SEP]``, 1 for ``b [SEP]``, 0 for padding;
              ``attentions`` (int32) 1 for real tokens, 0 for padding AND masked positions;
              ``positions``  (int32) ``0 .. max_length - 1``;
              ``is_next``    (bool)  ``[is_next]``.

        Raises:
            ValueError: if ``max_length`` < 3 (no room for the special tokens).
        """
        if max_length < 3:
            raise ValueError(f'max_length must be >= 3 to hold [CLS]/[SEP]/[SEP], got {max_length}')
        # end

        tokens_a = seq_a.split()
        tokens_b = seq_b.split()
        # [CLS] + 2 x [SEP] take 3 of the max_length slots.
        self._truncate_pair(tokens_a, tokens_b, max_length - 3)
        len_a, len_b = len(tokens_a), len(tokens_b)

        tokens_pair = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
        n_pad = max_length - len(tokens_pair)
        tokens_all = tokens_pair + ['[PAD]'] * n_pad

        # Maskable positions are the words only: a sits at 1..len_a, b at len_a+2..len_a+len_b+1.
        indexs_mask_all = list(range(1, len_a + 1)) + list(range(len_a + 2, len_a + 2 + len_b))
        shuffle = random.shuffle if rng is None else rng.shuffle
        shuffle(indexs_mask_all)
        # A long index tensor also handles the "nothing masked" case (empty selection) cleanly.
        t_index_masked = torch.tensor(indexs_mask_all[:int(len(indexs_mask_all) * probs_mask)], dtype=torch.long)

        t_segments_all = torch.IntTensor([0] * (len_a + 2) + [1] * (len_b + 1) + [0] * n_pad)
        t_attentions_all = torch.IntTensor([1] * len(tokens_pair) + [0] * n_pad)
        # Masked words get attention 0 in addition to being flagged in ``masks``.
        t_attentions_all[t_index_masked] = 0
        t_masks = torch.zeros(max_length, dtype=torch.bool)
        t_masks[t_index_masked] = True
        t_position_all = torch.arange(max_length, dtype=torch.int32)
        t_tokens_id = torch.IntTensor(self.tokenizer.convert_tokens_to_ids(tokens_all))

        t_isnext = torch.BoolTensor([is_next])

        return {
            'tokens_id': t_tokens_id,
            'masks': t_masks,
            'segments': t_segments_all,
            'attentions': t_attentions_all,
            'positions': t_position_all,
            'is_next': t_isnext
        }
    # end
# end

In [7]:
class SimpleBatchMaker:
    """Collates per-example feature dicts into one batched dict of tensors."""

    @classmethod
    def make_batch(cls, list_dict_info):
        """Stack every field of ``list_dict_info`` along a new leading batch axis.

        Args:
            list_dict_info: non-empty list of dicts sharing the same keys, each value
                a tensor; tensors under the same key must have the same shape.

        Returns:
            dict mapping each key to a tensor of shape ``(len(list_dict_info), *item.shape)``.

        Raises:
            ValueError: if the list is empty or the dicts do not all have the same keys.
        """
        if not list_dict_info:
            raise ValueError('make_batch() needs at least one example')
        # end

        keys_dict = list_dict_info[0].keys()
        for index, dict_info in enumerate(list_dict_info[1:], start=1):
            if dict_info.keys() != keys_dict:
                raise ValueError(
                    f'example {index} has keys {sorted(dict_info.keys())}, expected {sorted(keys_dict)}'
                )
            # end
        # end

        # torch.stack adds the batch axis itself (same result as cat of item[None, :]).
        return {
            key_dict: torch.stack([dict_info[key_dict] for dict_info in list_dict_info], dim=0)
            for key_dict in keys_dict
        }
    # end
# end

In [8]:
# Two next-sentence-prediction examples: one labelled is_next=True, one is_next=False.
# SimpleTokenizer picks masked positions with the global `random` generator, so
# seed it here to make Restart & Run All reproduce the same masks.
SEED = 0
random.seed(SEED)

# sample_1
seq_first_1 = 'hello my name is john nice to meet you today is a good day is not it'
seq_second_1 = 'hello i am marry first time to see you'
is_next_1 = True

# sample_2
seq_first_2 = 'hello my name is hello kitty'
seq_second_2 = 'today is a good day for work and i go to the office'
is_next_2 = False

tokenizer = SimpleTokenizer()
sample_1 = tokenizer.generate_training_embedding(seq_first_1, seq_second_1, is_next=is_next_1)
sample_2 = tokenizer.generate_training_embedding(seq_first_2, seq_second_2, is_next=is_next_2)

samples = SimpleBatchMaker.make_batch([sample_1, sample_2])
# Each example's `is_next` has shape (1,), so the batched labels are (2, 1).
samples['is_next'].shape



torch.Size([2, 1])