In [1]:
from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertTokenizerFast
from transformers.configuration_utils import PretrainedConfig
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
from torch import nn
import json
import random
import re
import datasets

In [2]:
class Collator_BERT:
    """DataLoader ``collate_fn`` that turns raw text lines into masked-LM batches.

    Each call cleans the lines, tokenizes them without special tokens, greedily
    packs consecutive lines into sequences of at most ``size_seq_max`` tokens
    (counting [CLS] and [SEP]), pads every sequence to ``size_seq_max`` with
    [PAD] and replaces a random fraction ``need_masked`` of the inner
    positions with [MASK].

    Args:
        tokenizer: tokenizer exposing ``all_special_tokens`` /
            ``all_special_ids`` (which must include [PAD], [MASK], [CLS],
            [SEP], [UNK]) and ``batch_encode_plus``.
        size_seq_max: fixed length of every output sequence.
        need_masked: fraction of the positions between [CLS] and [SEP] that
            are replaced by [MASK] in the encoder input.
    """

    def __init__(self, tokenizer, size_seq_max, need_masked=0.3):
        self.tokenizer = tokenizer
        self.size_seq_max = size_seq_max
        self.need_masked = need_masked

        index_special_token_2_id = dict(zip(tokenizer.all_special_tokens, tokenizer.all_special_ids))

        self.id_pad = index_special_token_2_id['[PAD]']
        self.id_mask = index_special_token_2_id['[MASK]']
        self.id_cls = index_special_token_2_id['[CLS]']
        self.id_sep = index_special_token_2_id['[SEP]']
        self.id_unk = index_special_token_2_id['[UNK]']

        # Literal "[MASK]", "[PAD]", ... in the raw text are rewritten to "<MASK>" etc.
        # by _preprocess, so they cannot be confused with the special ids inserted here
        # (pad_sequences derives the attention mask from `!= id_mask`).
        self.regex_special_token = re.compile(r'\[(PAD|MASK|CLS|SEP|EOL|UNK)\]')
    # end

    def _preprocess(self, line):
        """Normalise one raw text line before tokenization.

        Rewrites special-token literals (``[MASK]`` -> ``<MASK>``), removes any two
        consecutive quote/backtick characters (e.g. ``````, ``''``) and runs of 2-3
        dots, collapses repeated spaces and trims surrounding whitespace.
        """
        line = self.regex_special_token.sub(r'<\1>', line)
        line = re.sub(r'''('|"|`){2}''', '', line)
        line = re.sub(r'\.{2,3}', '', line)
        line = re.sub(r' {2,}', ' ', line)
        return line.strip()
    # end

    def __call__(self, list_sequence_batch):
        """Collate a list of raw text lines into a ``Batch``.

        Returns:
            Batch with ``input_ids`` (contains [MASK]; shape (n_packed, size_seq_max)),
            ``attention_mask`` (False on padding and [MASK] positions),
            ``labels`` (the unmasked ids; shape (n_packed, size_seq_max)) and
            ``segments`` (False on padding only). Both masks have shape
            (n_packed, 1, size_seq_max).
        """
        list_sequence_batch = [self._preprocess(sequence) for sequence in list_sequence_batch]   # neutralise special-token literals
        list_sequence_tokenized = self.tokenizer.batch_encode_plus(
            list_sequence_batch, add_special_tokens=False
        )['input_ids']

        list_tokenized_merged = self._pack_tokenized(list_sequence_tokenized)
        tokens_input_encoder = [[self.id_cls] + merged + [self.id_sep] for merged in list_tokenized_merged]

        # Batch only carries the encoder (MLM) view, so no decoder-side sequences are built.
        inputs_encoder, masks_encoder, segments_encoder, labels_encoder = self.pad_sequences(
            tokens_input_encoder, self.size_seq_max, need_masked=self.need_masked
        )

        return Batch(
            input_ids=inputs_encoder,  # contains [mask]s
            attention_mask=masks_encoder,
            labels=labels_encoder,  # doesn't contain [mask]
            output_hidden_states=False,
            segments=segments_encoder
        )
    # end

    def _pack_tokenized(self, list_sequence_tokenized):
        """Greedily pack token-id lists into chunks of at most ``size_seq_max - 2`` ids.

        Consecutive lists are concatenated until the next one would overflow the
        budget (2 slots are reserved for [CLS]/[SEP]). A single list longer than
        the budget starts its own chunk and is truncated.
        """
        budget = self.size_seq_max - 2
        list_list_tokenized = []
        list_tokenized_cache = []
        len_accumulated = 0

        for tokenized in list_sequence_tokenized:   # plain iteration: no O(n^2) pop(0)
            if list_tokenized_cache and len_accumulated + len(tokenized) > budget:
                list_list_tokenized.append(list_tokenized_cache)
                list_tokenized_cache = []
                len_accumulated = 0
            # end
            list_tokenized_cache.append(tokenized)
            len_accumulated += len(tokenized)
        # end
        list_list_tokenized.append(list_tokenized_cache)

        return [[token for tokenized in group for token in tokenized][:budget] for group in list_list_tokenized]
    # end

    def pad_sequences(self, sequences, size_seq_max, need_diagonal=False, need_masked=0):
        """Pad id lists to ``size_seq_max`` and build their masks.

        ``need_diagonal`` (causal mask, for decoder-style inputs) and ``need_masked``
        (random [MASK] replacement, for BERT inputs) are not meant to be combined.

        Args:
            sequences: lists of token ids, each at most ``size_seq_max`` long.
            size_seq_max: target length.
            need_diagonal: if True, the attention mask is (n, size_seq_max, size_seq_max)
                and also hides future positions.
            need_masked: fraction of inner positions (excluding first and last)
                replaced by [MASK]; 0 disables masking.

        Returns:
            ``(inputs, masks_attention, masks_segment, labels)``. With masking,
            ``inputs`` contains [MASK], ``labels`` holds the unmasked padded ids and
            ``masks_attention`` is also False on masked positions; without masking
            ``labels`` is None.

        Raises:
            ValueError: if a sequence is longer than ``size_seq_max``.
        """
        id_pad = self.id_pad
        id_mask = self.id_mask

        sequences_padded = []
        sequences_masked_padded = []

        for sequence in sequences:
            len_seq = len(sequence)
            count_pad = size_seq_max - len_seq
            if count_pad < 0:
                # Otherwise the row silently keeps its extra ids and the batch either
                # fails inside torch.stack or comes out wider than size_seq_max.
                raise ValueError(f'sequence of length {len_seq} exceeds size_seq_max={size_seq_max}')
            # end

            sequence = torch.LongTensor(sequence)
            padding = torch.LongTensor([id_pad] * count_pad)
            sequences_padded.append(torch.cat((sequence, padding)))

            if need_masked:
                # candidate positions skip the first and last id ([CLS] / [SEP])
                index_masked = list(range(1, len_seq - 1))
                random.shuffle(index_masked)
                index_masked = torch.LongTensor(index_masked[:int(need_masked * (len_seq - 2))])

                sequence_masked = sequence.detach().clone()
                sequence_masked.index_fill_(0, index_masked, id_mask)
                sequences_masked_padded.append(torch.cat((sequence_masked, padding)))
            # end
        # end

        inputs = torch.stack(sequences_padded)  # (batch, size_seq_max)
        masks_segment = (inputs != id_pad).unsqueeze(-2)  # (batch, 1, size_seq_max)
        masks_attention = self.make_std_mask(inputs, id_pad) if need_diagonal else masks_segment

        if not need_masked:
            return inputs, masks_attention, masks_segment, None
        # end

        inputs_masked_padded = torch.stack(sequences_masked_padded)
        masks_masked = (inputs_masked_padded != id_mask).unsqueeze(-2)
        return inputs_masked_padded, masks_attention & masks_masked, masks_segment, inputs  # labels = unmasked ids
    # end

    def subsequent_mask(self, size):
        "Mask out subsequent positions: (1, size, size) bool, True on and below the diagonal."
        attn_shape = (1, size, size)
        subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
        return subsequent_mask == 0
    # end

    def make_std_mask(self, tgt, pad):
        """Create a mask to hide padding and future words.

        For ``tgt`` of shape (batch, seq) the result is (batch, seq, seq) and True
        where the key position is neither padding nor in the future.
        """
        tgt_mask = (tgt != pad).unsqueeze(-2)
        return tgt_mask & self.subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
    # end
# end

In [3]:
def _split_train_eval(list_corpus, split, seed=None):
    """Randomly split ``list_corpus`` into ``(train, eval)`` lists.

    The first ``int(split * len(list_corpus))`` items of a random permutation form
    the eval set; the remaining items form the train set.

    Args:
        list_corpus: indexable sequence of samples.
        split: fraction of samples assigned to eval.
        seed: ``None`` shuffles with the module-level ``random`` state (so a global
            ``random.seed`` still controls it); an int uses a private
            ``random.Random(seed)`` so the split is reproducible on its own.
    """
    indexs_all = list(range(len(list_corpus)))
    if seed is None:
        random.shuffle(indexs_all)
    else:
        random.Random(seed).shuffle(indexs_all)
    # end

    index_split = int(split * len(list_corpus))
    list_corpus_eval = [list_corpus[i_e] for i_e in indexs_all[:index_split]]
    list_corpus_train = [list_corpus[i_t] for i_t in indexs_all[index_split:]]
    return list_corpus_train, list_corpus_eval
# end


def BookCorpus2000(split=0.1, filename='bookcorpus_2000.json', seed=None):
    """Load a JSON list of text samples and split it into train/eval.

    Args:
        split: fraction of samples used for evaluation.
        filename: JSON file holding a list of samples.
        seed: optional seed for a reproducible split (see ``_split_train_eval``).

    Returns:
        ``(list_corpus_train, list_corpus_eval, None)``; the third slot is always None.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(filename, 'r', encoding='utf-8') as file:
        list_corpus = json.load(file)
    # end

    list_corpus_train, list_corpus_eval = _split_train_eval(list_corpus, split, seed)
    return list_corpus_train, list_corpus_eval, None
# end


def BookCorpus(split=0.0001, used=-1, seed=None):
    """Load the HuggingFace ``bookcorpus`` train split and split it into train/eval.

    Args:
        split: fraction of lines used for evaluation.
        used: number of leading lines to keep; a negative value (default -1) or
            None keeps every line.
        seed: optional seed for a reproducible split (see ``_split_train_eval``).

    Returns:
        ``(list_corpus_train, list_corpus_eval, None)``; the third slot is always None.
    """
    split_train = datasets.load_dataset('bookcorpus')['train']   # 70,000,000, 70 Million
    if used is not None and used >= 0:
        # Select rows before materialising the text column. The former `[:used]` slice
        # also turned the default used=-1 into "drop the last line".
        split_train = split_train.select(range(min(used, len(split_train))))
    # end
    list_corpus = split_train['text']

    list_corpus_train, list_corpus_eval = _split_train_eval(list_corpus, split, seed)
    return list_corpus_train, list_corpus_eval, None
# end


class Batch:
    """Bundle of model keyword inputs, moved to ``Batch.DEVICE`` on construction.

    ``None`` values and plain ``bool`` flags are dropped; every other value must
    provide ``.to(device)`` (e.g. a tensor). Calling the instance returns the
    stored keyword dict itself (not a copy).
    """
    DEVICE = 'cuda'

    def __init__(self, **kwargs):
        self.kwargs = {
            name: value.to(Batch.DEVICE)
            for name, value in kwargs.items()
            if not (value is None or type(value) is bool)
        }
    # end

    def __call__(self):
        return self.kwargs
    # end
# end

class Dotdict(dict):
    """dict whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute returns None (``__getattr__`` is ``dict.get``)
    instead of raising. ``d += other`` adds ``other[key]`` onto ``d[key]`` for
    every key already in ``d`` whose value in ``other`` is truthy; keys present
    only in ``other`` are ignored.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __iadd__(self, other):
        for key in self.keys():
            increment = other[key] if key in other else None
            if increment:
                self[key] += increment
            # end
        # end
        return self
    # end
# end


def get_info_accuracy_template_mlm():
    """Return a fresh Dotdict of zeroed MLM accuracy counters.

    The accuracy loop fills ``corrects_*`` with correct argmax predictions and
    ``num_*`` with scored positions, over all non-pad tokens (``segmented``) or
    only the [MASK] positions (``masked``).
    """
    counter_names = ('corrects_segmented', 'corrects_masked', 'num_segmented', 'num_masked')
    return Dotdict.fromkeys(counter_names, 0)
# end


In [4]:
import transformers
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Seed the RNGs this notebook draws from: `random` drives the corpus split and the
# [MASK] selection in the collator, torch drives any weight initialisation.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

gpu = 1
torch.cuda.set_device(gpu)  # plain 'cuda' below (and Batch.DEVICE) now refers to this GPU
#### @jingdi start
# NOTE(review): dim, dropout, n_layers, n_heads and epoch are only referenced by the
# commented-out config below in the visible cells.
dim = 768
seq_max = 512
dropout=0.1
n_layers=6
n_heads=12
batch_size=16
epoch = 1
############ ends

# config_pretrained = PretrainedConfig(
#     vocab_size=tokenizer.vocab_size,
#     dim=dim,
#     dropout=dropout,
#     max_position_embeddings=seq_max,
#     attention_dropout=dropout,
#     n_layers=n_layers,
#     num_hidden_layers=n_layers,
#     n_heads=n_heads,
#     hidden_dim=dim,
#     activation='relu',
#     initializer_range=0.02,
#     sinusoidal_pos_embds=True,
# )

model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased').to('cuda')
# model = DistilBertModel.from_pretrained("distilbert-base-uncased").to('cuda')

# Xavier re-initialisation of every weight matrix overwrites the pretrained weights
# loaded just above, so the MLM accuracy measured later would be that of a random
# model. Enable it only when deliberately training from scratch.
REINIT_WEIGHTS = False
if REINIT_WEIGHTS:
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        # end
    # end
# end

#### @jingdi start
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
# decayRate = 0.96
# # lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
# lr_scheduler = transformers.get_scheduler(
#     name="cosine_with_restarts", optimizer=optimizer, num_warmup_steps=1000000, num_training_steps=624375 * 10
# )
############ ends

In [5]:
func_loss = torch.nn.CrossEntropyLoss()  # NOTE(review): not used by any visible cell

# Random 90/10 train/eval split of the local bookcorpus_2000.json sample (@jingdi).
source_train, source_eval, _ = BookCorpus2000()

# The collator tokenizes, packs and masks raw lines into `Batch` objects.
collator = Collator_BERT(tokenizer, seq_max)
loader_options = dict(shuffle=False, collate_fn=collator)
dataloader_train = DataLoader(source_train, batch_size=batch_size, **loader_options)
dataloader_eval = DataLoader(source_eval, batch_size=1, **loader_options)

In [7]:
info_acc_train = get_info_accuracy_template_mlm()
info_acc_epoch = get_info_accuracy_template_mlm()
losss_train = []


# Accuracy pass over the training split. Nothing is optimised here, so run in eval
# mode and without autograd; otherwise every step keeps the graph behind the
# (batch, seq, vocab) logits alive and wastes GPU memory.
model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader_train):
        info_batch = batch()

        # `segments` and `labels` are only used for scoring: pop them so the remaining
        # kwargs (input_ids, attention_mask) go to the model and no loss is computed.
        segments_encoder = info_batch.pop('segments')   # (batch, 1, seq): True on non-pad tokens
        masks_encoder = info_batch['attention_mask']    # segments with [MASK] positions set False
        labels_mlm = info_batch.pop('labels')           # (batch, seq): unmasked token ids

        output_mlm = model(**info_batch).logits         # (batch, seq, vocab)
        # output_mlm = model(**info_batch)['last_hidden_state']

        info_acc = get_info_accuracy_template_mlm()

        # Accuracy over every non-pad position.
        segments_encoder_2d = segments_encoder[:, 0, :]  # (batch, seq)
        hidden_mlm_segmented = output_mlm.masked_select(segments_encoder_2d.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1]) # should be (segmented_all_batchs, size_vocab)

        info_acc.corrects_segmented = torch.sum(hidden_mlm_segmented.argmax(-1) == labels_mlm.masked_select(segments_encoder_2d)).item()
        info_acc.num_segmented = hidden_mlm_segmented.shape[0]

        # Accuracy over [MASK] positions only: non-pad but excluded from attention_mask.
        masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder # True is masked
        masks_masked_perbatch = masks_masked[:, 0, :]
        hidden_mlm_masked = output_mlm.masked_select(masks_masked_perbatch.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1])

        info_acc.corrects_masked = torch.sum(hidden_mlm_masked.argmax(-1) == labels_mlm.masked_select(masks_masked_perbatch)).item()
        info_acc.num_masked = hidden_mlm_masked.shape[0]

        info_acc_epoch += info_acc
    # end
# end

info_acc_train += info_acc_epoch

# max(..., 1) guards against an empty loader.
print('acc_all_tokens: {:.4f}  acc_mlm: {:.4f}'.format(
    info_acc_train.corrects_segmented / max(info_acc_train.num_segmented, 1),
    info_acc_train.corrects_masked / max(info_acc_train.num_masked, 1),
))


  

  0%|          | 0/113 [00:00<?, ?it/s]


In [None]:
# NOTE(review): scratch inspection cells. They display tensors left over from the last
# iteration of the accuracy loop above (hidden_mlm_segmented, info_batch, labels_mlm,
# output_mlm, segments_encoder_2d), so on a fresh kernel they raise NameError unless
# that loop has run at least once.
hidden_mlm_segmented.argmax(-1)

In [10]:
# Re-runs the model on the last batch; the recorded output is (batch, seq, vocab).
model(**info_batch).logits.shape

torch.Size([1, 512, 30522])

In [None]:
labels_mlm.shape

In [None]:
segments_encoder_2d.shape

In [None]:
output_mlm.shape[-1]

In [None]:
output_mlm.shape

In [None]:
# Duplicate of an earlier cell (labels_mlm.shape).
labels_mlm.shape

In [None]:
labels_mlm

In [None]:
# Label ids at non-pad positions, flattened to 1-D by masked_select.
labels_mlm.masked_select(segments_encoder_2d)

In [None]:
segments_encoder_2d

In [None]:
model

In [None]:
# model.eval()

# info_acc_eval = get_info_accuracy_template_mlm()

# for e in range(1):
#     info_acc_epoch = get_info_accuracy_template_mlm()
    
#     for b, batch in enumerate(tqdm(dataloader_eval)):
#         info_batch = batch()
        
#         segments_encoder = info_batch['segments']
#         masks_encoder = info_batch['attention_mask']
#         labels_mlm = info_batch['labels']
#         del info_batch['segments']
        
#         with torch.no_grad():
#             output_mlm = model(**info_batch).logits
#         # end

#         info_acc = get_info_accuracy_template_mlm()
        
#         segments_encoder_2d = segments_encoder.transpose(-1,-2)[:,:,0]
#         hidden_mlm_segmented = output_mlm.masked_select(segments_encoder_2d.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1]) # should be (segmented_all_batchs, size_vocab)
        
#         info_acc.corrects_segmented = torch.sum(hidden_mlm_segmented.argmax(-1) == labels_mlm.masked_select(segments_encoder_2d)).cpu().item()
#         info_acc.num_segmented = hidden_mlm_segmented.shape[0]
        
#         masks_masked = torch.logical_xor(masks_encoder, segments_encoder) & segments_encoder # True is masked
#         masks_masked_perbatch = masks_masked[:,0,:]
#         hidden_mlm_masked = output_mlm.masked_select(masks_masked_perbatch.unsqueeze(-1)).reshape(-1, output_mlm.shape[-1])
        
#         info_acc.corrects_masked = torch.sum(hidden_mlm_masked.argmax(-1) == labels_mlm.masked_select(masks_masked_perbatch)).cpu().item()
#         info_acc.num_masked = hidden_mlm_masked.shape[0]
        
#         info_acc_epoch += info_acc
#     # end
    
#     info_acc_eval += info_acc_epoch

# # end

# print(
#     'Eval ends. Result:  acc_mlm: {}'.format(
#         info_acc_eval.corrects_masked / info_acc_eval.num_masked
#     )
# )