In [39]:
import numpy as np
import torch
from torch.nn import Embedding

import random
from transformers import BertTokenizer






class Simple_BertEmbedder(torch.nn.Module):
    """BERT input embedding: token + position + segment embeddings, then LayerNorm and dropout.

    Args:
        length_seq_max: size of the position-embedding table (max sequence length).
        length_vocab_max: size of the token-embedding table (vocabulary size).
        length_segment_max: size of the segment-embedding table (number of segment ids).
        dim_hidden: embedding width.
        prob_drop: dropout probability applied to the normalized sum.
    """

    def __init__(self, length_seq_max, length_vocab_max, length_segment_max, dim_hidden, prob_drop=0.15):
        super().__init__()
        self.length_seq_max = length_seq_max
        self.length_vocab_max = length_vocab_max
        self.length_segment_max = length_segment_max
        self.dim_hidden = dim_hidden

        self.embedding_token = Embedding(length_vocab_max, dim_hidden)  # num_embedding, embedding_dim
        self.embedding_position = Embedding(length_seq_max, dim_hidden)
        self.embedding_segment = Embedding(length_segment_max, dim_hidden)
        self.norm = torch.nn.LayerNorm(dim_hidden)
        self.dropout = torch.nn.Dropout(p=prob_drop)

    # end

    # TODO: segment embedding 0,0,0,1,1,1,0,0,0
    def forward(self, token_ids=None, masked=None, segments=None, attentions=None, positions=None, is_next=None):
        """Embed a batch of token ids.

        Args:
            token_ids: integer tensor of token ids, e.g. (batch_size, len_seq).
            masked, attentions, is_next: accepted so a whole sample dict can be
                passed through; they are not used here.
            segments: segment ids, same shape as token_ids; defaults to all zeros.
            positions: position ids, same shape as token_ids; defaults to
                0..len_seq-1 along the last dimension for every row.

        Returns:
            Tensor of shape token_ids.shape + (dim_hidden,).
        """
        if positions is None:
            positions = torch.arange(token_ids.shape[-1], device=token_ids.device).expand_as(token_ids)
        # end
        if segments is None:
            segments = torch.zeros_like(token_ids)
        # end

        e_token = self.embedding_token(token_ids)
        e_segment = self.embedding_segment(segments)
        e_position = self.embedding_position(positions)

        sum_embedding = self.norm(e_token + e_segment + e_position)
        return self.dropout(sum_embedding)
    # end
# end


class Simple_SelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention where every head is ``dim_hidden`` wide.

    Q/K/V each project dim_hidden -> dim_hidden * num_head; the concatenated head
    outputs are projected back to dim_hidden by ``linear_out``.
    """

    def __init__(self, dim_hidden, num_head=6):
        super().__init__()
        self.dim_hidden = dim_hidden
        self.num_head = num_head
        self.linear_Q = torch.nn.Linear(dim_hidden, dim_hidden * num_head)
        self.linear_K = torch.nn.Linear(dim_hidden, dim_hidden * num_head)
        self.linear_V = torch.nn.Linear(dim_hidden, dim_hidden * num_head)
        self.linear_out = torch.nn.Linear(dim_hidden * num_head, dim_hidden)

    # end

    def forward(self, seq_in, attentions):
        """Attend over the sequence, ignoring keys whose mask entry is 0/False.

        Args:
            seq_in: (batch_size, len_seq, dim_hidden).
            attentions: (batch_size, len_seq) key mask; nonzero/True = may be
                attended to, 0/False = excluded. Bool, integer and float masks
                are all accepted.

        Returns:
            seq_out: (batch_size, len_seq, dim_hidden).
            scores: (batch_size, num_head, len_seq, len_seq) attention probabilities.
        """
        size_batch, len_seq, _ = seq_in.shape
        dim_hidden = self.dim_hidden

        def _split_heads(projected):
            # batch_size, len_seq, num_head * dim_hidden -> batch_size, num_head, len_seq, dim_hidden
            return projected.view(size_batch, len_seq, -1, dim_hidden).transpose(1, 2)
        # end

        q = _split_heads(self.linear_Q(seq_in))
        k = _split_heads(self.linear_K(seq_in))
        v = _split_heads(self.linear_V(seq_in))

        w_qk = torch.matmul(q, k.transpose(-1, -2)) / (dim_hidden ** 0.5)  # batch_size, num_head, len_seq, len_seq

        # Broadcast the key mask over heads and query positions. masked_fill with the
        # dtype's most negative value works for any mask dtype (``1 - bool`` raises) and
        # cannot overflow in reduced precision the way a literal -1e10 can.
        keep = (attentions != 0)[:, None, None, :]  # batch_size, 1, 1, len_seq
        scores_raw = w_qk.masked_fill(~keep, torch.finfo(w_qk.dtype).min)
        scores = torch.nn.functional.softmax(scores_raw, dim=-1)  # batch_size, num_head, len_seq, len_seq

        z = torch.matmul(scores, v)  # batch_size, num_head, len_seq, dim_hidden
        w_z = z.transpose(1, 2).contiguous().view(size_batch, len_seq, -1)  # batch_size, len_seq, num_head * dim_hidden

        seq_out = self.linear_out(w_z)  # batch_size, len_seq, dim_hidden
        return seq_out, scores
    # end
# end


class Simple_Positionwise_FeedforwardNet(torch.nn.Module):
    """Position-wise MLP applied independently at every sequence position.

    Computes ``linear_out(dropout(relu(linear_1(x))))``. The inner width is
    ``dim_network`` when given, otherwise it equals ``dim_hidden``.
    """

    def __init__(self, dim_hidden, dim_network=None, proba_drop=0.15):
        super().__init__()

        width_inner = dim_hidden if dim_network is None else dim_network

        self.linear_1 = torch.nn.Linear(dim_hidden, width_inner)
        self.activation_1 = torch.nn.ReLU()
        self.dropout_1 = torch.nn.Dropout(p=proba_drop)
        self.linear_out = torch.nn.Linear(width_inner, dim_hidden)
    # end

    def forward(self, seq_in):
        hidden = self.linear_1(seq_in)
        hidden = self.activation_1(hidden)
        hidden = self.dropout_1(hidden)
        return self.linear_out(hidden)
    # end
# end

class Simple_NormResidual(torch.nn.Module):
    """Post-norm residual join: returns ``LayerNorm(origin + target)``."""

    def __init__(self, dim_hidden):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim_hidden)
    # end

    def forward(self, origin, target):
        combined = torch.add(origin, target)
        return self.norm(combined)
    # end
# end

class Simple_TransformerEncoder(torch.nn.Module):
    """One post-norm encoder layer.

    self-attention -> Add&Norm -> position-wise feed-forward -> Add&Norm.
    """

    def __init__(self, dim_hidden):
        super().__init__()

        self.layer_selfattention = Simple_SelfAttention(dim_hidden)
        self.layer_positionwise_feedforwardnet = Simple_Positionwise_FeedforwardNet(dim_hidden)
        self.layer_norm1 = Simple_NormResidual(dim_hidden)
        self.layer_norm2 = Simple_NormResidual(dim_hidden)

    # end

    def forward(self, seq_in, attentions):
        """Run one encoder layer.

        Args:
            seq_in: layer input, e.g. (batch_size, len_seq, dim_hidden).
            attentions: key mask forwarded to the self-attention sublayer.

        Returns:
            (seq_out, scores_attention): the layer output and the attention
            probabilities produced by the self-attention sublayer (previously
            the input mask was returned here by mistake).
        """
        seq_attention, scores_attention = self.layer_selfattention(seq_in, attentions)
        seq_norm1 = self.layer_norm1(seq_in, seq_attention)
        seq_feedforwardnet = self.layer_positionwise_feedforwardnet(seq_norm1)
        seq_norm2 = self.layer_norm2(seq_norm1, seq_feedforwardnet)
        return seq_norm2, scores_attention
    # end
# end


class Simple_Bert(torch.nn.Module):
    """BERT backbone: a Simple_BertEmbedder followed by ``num_encoder`` stacked encoder layers.

    Args:
        length_seq_max, length_vocab_max, length_segment_max: embedding table sizes,
            forwarded to Simple_BertEmbedder.
        dim_hidden: model width.
        num_encoder: number of Simple_TransformerEncoder layers.
    """

    def __init__(self, length_seq_max, length_vocab_max, length_segment_max, dim_hidden, num_encoder=6):
        super().__init__()
        self.embedder = Simple_BertEmbedder(length_seq_max, length_vocab_max, length_segment_max, dim_hidden)
        # ModuleList (not a plain list) registers the layers as submodules. Without it their
        # weights are absent from parameters()/state_dict(), so the optimizer never updates
        # them, they are never saved, and .to()/.train()/.eval() skip them.
        self.layers_encoder = torch.nn.ModuleList(
            [Simple_TransformerEncoder(dim_hidden) for _ in range(num_encoder)]
        )
        self.dim_hidden = dim_hidden
    # end

    def forward(self, token_ids=None, masked=None, segments=None, attentions=None, positions=None, is_next=None):
        """Embed the inputs and run them through every encoder layer.

        All arguments are forwarded to the embedder; ``attentions`` is also passed
        to each encoder layer as its key mask. Returns the last layer's output.
        """
        seq_out = self.embedder(token_ids, masked, segments, attentions, positions, is_next)

        for layer_encoder in self.layers_encoder:
            seq_out, _ = layer_encoder(seq_out, attentions)  # seq_out, scores(batch_size, num_head, len_seq, len_seq)
        # end

        return seq_out
    # end
# end

In [40]:
# for Masked Language Model
class Simple_BertMLMTiedDecoder(torch.nn.Module):
    '''
        Masked-LM head whose output projection is tied to a token Embedding.

        Applied to the masked positions only:
        Linear -> GELU -> LayerNorm -> tied Linear (no bias) + separate vocab bias.

        No transpose is needed when tying: nn.Linear(d, V).weight has shape
        (V, d) (out_features, in_features), the same layout as
        nn.Embedding(V, d).weight:

        a = torch.nn.Embedding(3,2)
        b = torch.nn.Linear(2,3, bias=False)
        b.weight = a.weight
        seq_a = torch.ones([1,1,3], dtype=torch.int64)
        seq_b = torch.ones([1,1,2])

        a(seq_a)    # this works
        b(seq_b)    # this also works
    '''

    def __init__(self, tied_embedder: torch.nn.Embedding):
        super().__init__()

        self.dim_in = tied_embedder.embedding_dim
        self.dim_out = tied_embedder.num_embeddings

        self.linear_1 = torch.nn.Linear(self.dim_in, self.dim_in)
        self.activation_1 = torch.nn.GELU()
        self.norm_1 = torch.nn.LayerNorm(self.dim_in)

        # torch.Tensor(n) is uninitialized memory (may hold garbage/NaN); start the bias at zero.
        self.bias = torch.nn.Parameter(torch.zeros(self.dim_out))
        self.linear_decoder = torch.nn.Linear(self.dim_in, self.dim_out, bias=False)
        self.linear_decoder.weight = tied_embedder.weight

    # end

    def forward(self, seq_in, masks):
        """Predict vocabulary logits for the masked positions.

        Args:
            seq_in: (batch_size, length_seq, dim_hidden) encoder output.
            masks: (batch_size, length_seq) selector; True/nonzero marks a
                position to predict.

        Returns:
            (1, length_masked_total, dim_vocab) logits. Masked positions from all
            rows are concatenated in row-major order into a single pseudo-batch.
        """
        dim_hidden = seq_in.shape[-1]
        masks_one = masks.bool()[:, :, None].expand_as(seq_in)

        # TODO: 1.using gather? 2.remove batch concept?
        seq_masked = torch.masked_select(seq_in, masks_one).view(1, -1, dim_hidden)
        h_masked = self.norm_1(self.activation_1(self.linear_1(seq_masked)))

        return self.linear_decoder(h_masked) + self.bias  # 1, length_masked_total, dim_vocab
    # end
# end




# for NextSentencePrediction
class Simple_BertCLSDecoder(torch.nn.Module):
    """Next-sentence-prediction head over the first ([CLS]) position.

    Takes ``seq_in[:, 0, :]``, applies Linear -> Tanh, then a Linear to
    ``dim_out`` logits (2 by default).
    """

    def __init__(self, dim_in, dim_out=2):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim_in, dim_in)
        self.activation_1 = torch.nn.Tanh()
        self.linear_out = torch.nn.Linear(dim_in, dim_out)
    # end

    def forward(self, seq_in):
        vector_first = seq_in[:, 0, :]
        pooled = self.linear_1(vector_first)
        pooled = self.activation_1(pooled)
        logits = self.linear_out(pooled)
        return logits
    # end
# end



In [41]:
class SimpleTokenizer:
    """Builds BERT pre-training samples (MLM + NSP) from a pair of sentences.

    Sentences are split on whitespace; the wrapped 'bert-base-uncased'
    ``BertTokenizer`` is used only to map those tokens (and the special tokens)
    to vocabulary ids.
    """

    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.length_seq_max = 64
        self.length_vocab_max = self.tokenizer.vocab_size
        self.length_segment_max = 2

    # end

    @staticmethod
    def _truncate_pair(tokens_a, tokens_b, length_content_max):
        """Trim tokens_a/tokens_b in place, always from the end of the longer one, until they fit."""
        while len(tokens_a) + len(tokens_b) > length_content_max:
            tokens_longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
            tokens_longer.pop()
        # end
    # end

    def generate_training_embedding(self, seq_a, seq_b, probs_mask=0.15, max_length=None, is_next=True):
        '''Build one padded training sample for the pair (seq_a, seq_b).

        Layout: [CLS] seq_a [SEP] seq_b [SEP] [PAD]..., exactly ``max_length`` long.
        If the pair does not fit, tokens are dropped from the end of the longer
        sentence until it does.

        Args:
            seq_a, seq_b: whitespace-separated sentences.
            probs_mask: fraction of non-special tokens selected for MLM. When it
                is > 0, at least one token is selected (if any exist).
            max_length: total sequence length; defaults to ``self.length_seq_max``.
            is_next: NSP label, stored as a 1-element LongTensor.

        Returns:
            dict of 1-D tensors 'token_ids', 'masked', 'segments', 'attentions',
            'positions' (each of length ``max_length``) and 'is_next', e.g.:

            {
                "tokens_id":[
                    101,7592,2026,2171,2003,2198,3835,2000,3113,2017,2651,2003,1037,2204,2154,2003,2025,2009,102,7592,1045,2572,5914,2034,2051,2000,2156,2017,102,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
                "masks":[
                    false,false,false,false,false,false,false,true,false,false,false,false,false,false,false,false,true,false,false,false,false,false,false,false,false,false,false,true,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false],
                "segments":[
                    0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
                "attentions":[
                    1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,0,1,1,1,1,1,1,1,1,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
                "positions":[
                    0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63],
                "is_next": true
            }

        Raises:
            ValueError: if max_length < 3 (no room for [CLS] and two [SEP]).

        NOTE(review): masked positions keep their ORIGINAL token id in
        'token_ids' (there is no [MASK] substitution); only 'attentions' is zeroed
        there. The encoder's residual connections carry each position's own token
        embedding forward, so the MLM target can leak to the MLM head. Replacing
        masked ids with '[MASK]' requires also returning the original ids as
        labels, because SimpleBertPretrainer reads its MLM labels from 'token_ids'.
        '''
        if max_length is None:
            max_length = self.length_seq_max
        # end
        if max_length < 3:
            raise ValueError('max_length must be at least 3 to hold [CLS] and two [SEP] tokens')
        # end

        tokens_a = seq_a.split()
        tokens_b = seq_b.split()
        self._truncate_pair(tokens_a, tokens_b, max_length - 3)

        tokens_pair = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']

        # MLM candidates: every non-special position (skip [CLS] at 0 and the middle [SEP]).
        indexs_mask_all = [i + 1 for i in range(len(tokens_a))] + [i + 2 + len(tokens_a) for i in range(len(tokens_b))]
        random.shuffle(indexs_mask_all)
        num_masked = int(len(indexs_mask_all) * probs_mask)
        if num_masked == 0 and probs_mask > 0 and indexs_mask_all:
            # int() truncation gives 0 for short pairs, leaving the sample with no MLM targets.
            num_masked = 1
        # end
        indexs_masked = indexs_mask_all[:num_masked]

        tokens_pad = ['[PAD]'] * (max_length - len(tokens_pair))
        tokens_all = tokens_pair + tokens_pad

        t_segments_all = torch.LongTensor([0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1) + [0] * len(tokens_pad))
        t_attentions_all = torch.LongTensor([1] * len(tokens_pair) + [0] * len(tokens_pad))
        t_attentions_all[indexs_masked] = 0
        t_masks = torch.zeros(len(tokens_all), dtype=torch.bool)
        t_masks[indexs_masked] = True
        t_position_all = torch.LongTensor(list(range(len(tokens_all))))
        t_tokens_id = torch.LongTensor(self.tokenizer.convert_tokens_to_ids(tokens_all))

        t_isnext = torch.LongTensor([is_next])

        return {
            'token_ids': t_tokens_id,
            'masked': t_masks,
            'segments': t_segments_all,
            'attentions': t_attentions_all,
            'positions': t_position_all,
            'is_next': t_isnext
        }
    # end
# end

In [42]:
class SimpleBatchMaker:
    """Collates a list of per-sample tensor dicts into one batched dict."""

    @classmethod
    def make_batch(cls, list_dict_info):
        """Concatenate each key's tensors along a new leading batch dimension.

        Keys (and their order) are taken from the first sample; every sample
        must provide the same keys with tensors of matching shapes.
        """
        return {
            key: torch.cat([info[key][None, :] for info in list_dict_info], dim=0)
            for key in list_dict_info[0]
        }
    # end
# end

In [43]:
# Configuration: hidden width plus one seed for both `random` (MLM position shuffle in
# SimpleTokenizer) and torch (weight init, dropout) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

dim_hidden = 128

# sample_1: consecutive pair (is_next=True)
seq_first_1 = 'hello my name is john nice to meet you today is a good day is not it'
seq_second_1 = 'hello i am marry first time to see you'
is_next_1 = True

# sample_2: unrelated pair (is_next=False)
seq_first_2 = 'hello my name is hello kitty'
seq_second_2 = 'today is a good day for work and i go to the office'
is_next_2 = False

tokenizer = SimpleTokenizer()

length_seq_max = tokenizer.length_seq_max
length_vocab_max = tokenizer.length_vocab_max
length_segment_max = tokenizer.length_segment_max


sample_1 = tokenizer.generate_training_embedding(seq_first_1, seq_second_1, is_next=is_next_1)
sample_2 = tokenizer.generate_training_embedding(seq_first_2, seq_second_2, is_next=is_next_2)

samples = SimpleBatchMaker.make_batch([sample_1, sample_2])

In [44]:
# Backbone sized by the tokenizer limits and the configured hidden width (no duplicated literal).
bert_1 = Simple_Bert(length_seq_max, length_vocab_max, length_segment_max, dim_hidden)

In [45]:
class SimpleBertPretrainer(torch.nn.Module):
    """Wraps a Simple_Bert backbone with NSP and MLM heads and returns the summed loss."""

    def __init__(self, bert):
        super().__init__()
        self.bert = bert

        self.dim_hidden = bert.dim_hidden
        self.head_nsp = Simple_BertCLSDecoder(self.dim_hidden)
        # The MLM output projection shares its weight with the token embedding table.
        self.head_mlm = Simple_BertMLMTiedDecoder(self.bert.embedder.embedding_token)

        self.fn_loss_nsp = torch.nn.CrossEntropyLoss()
        self.fn_loss_mlm = torch.nn.CrossEntropyLoss()

    # end

    def forward(self, token_ids=None, masked=None, segments=None, attentions=None, positions=None, is_next=None):
        """Compute NSP loss + MLM loss for a batch.

        Args:
            token_ids: input ids; the ids at ``masked`` positions are the MLM labels.
            masked: bool selector of MLM positions, same shape as token_ids.
            segments, attentions, positions: forwarded to the backbone.
            is_next: NSP labels of shape (batch_size,) or (batch_size, 1).

        Returns:
            Scalar loss tensor. If no position in the batch is masked, the MLM term
            is 0 instead of the NaN that CrossEntropyLoss gives for zero targets.
        """
        seq_bert = self.bert(token_ids=token_ids, masked=masked, segments=segments, attentions=attentions, positions=positions, is_next=is_next)
        # head_mlm packs masked positions of the whole batch into one pseudo-batch of size 1.
        h_mlm = self.head_mlm(seq_bert, masked).squeeze(0)  # length_masked_total, dim_vocab

        masked_ids = token_ids.masked_select(masked).to(torch.int64)
        if masked_ids.numel() > 0:
            loss_mlm = self.fn_loss_mlm(h_mlm, masked_ids).float()
        else:
            # Keeps the graph connected while contributing nothing to the loss.
            loss_mlm = h_mlm.sum().float() * 0.0
        # end

        h_nsp = self.head_nsp(seq_bert)
        loss_nsp = self.fn_loss_nsp(h_nsp, is_next.reshape(-1)).float()

        loss_all = loss_nsp + loss_mlm
        return loss_all
    # end

# end



In [46]:
# Attach the NSP and MLM pre-training heads to the bert_1 backbone.
bert_pretrainer = SimpleBertPretrainer(bert=bert_1)

In [47]:
# One pre-training step on the two-sample batch.
# AdamW applies decoupled weight decay; torch.optim.Adam's weight_decay instead adds an
# L2 term to the gradient, which then gets rescaled by Adam's adaptive denominators.
bert_pretrainer.train()
optimizer = torch.optim.AdamW(bert_pretrainer.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01)
optimizer.zero_grad()
loss_all = bert_pretrainer(**samples)
loss_all.backward()
optimizer.step()
loss_all.item()

In [48]:
# Persist bert_1's weights; the NSP/MLM heads live on bert_pretrainer and are not included.
state_bert_1 = bert_1.state_dict()
torch.save(state_bert_1, 'bert_1.pt')