# Make cache (MLM)

In [None]:
# pip install boto3
# pip install requests
# pip install tqdm

import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

from tokenization import BertTokenizer
import tokenization as tokenization

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance(object):
    """Padded, id-encoded features for one masked-LM training example.

    NOTE(review): ``input_ids``, ``input_mask``, ``segment_ids`` and
    ``masked_lm_ids_maxseq`` are each stored wrapped in a 1-tuple (the
    historical assignments ended in trailing commas), whereas
    ``next_sentence_labels`` is stored exactly as passed. Existing pickled
    caches carry this layout and downstream readers presumably unwrap it with
    ``[0]`` -- confirm with those readers before changing it.
    """

    def __init__(self, input_ids, input_mask, segment_ids,
                 masked_lm_ids_maxseq,
                 next_sentence_labels):
        # Explicit 1-tuples reproduce the historical attribute layout.
        self.input_ids = (input_ids,)
        self.input_mask = (input_mask,)
        self.segment_ids = (segment_ids,)
        self.masked_lm_ids_maxseq = (masked_lm_ids_maxseq,)
        self.next_sentence_labels = next_sentence_labels

In [None]:
def write_instance_to_example_files(instances, 
                                    tokenizer, 
                                    vocab_words,
                                    max_seq_length,
                                    max_predictions_per_seq, 
                                    outfilename):
    """Encode ``instances`` as padded ``TrainingInstance`` features and pickle them.

    Args:
        instances: objects exposing ``tokens``, ``segment_ids``,
            ``is_random_next`` and ``masked_lm_labels`` (as built by
            ``TrainingInstance_tmp``). ``masked_lm_labels`` must have at least
            ``max_seq_length`` entries: a token string at masked positions and
            0 elsewhere.
        tokenizer: object providing ``convert_tokens_to_ids``.
        vocab_words: ``Vocab_words``; its ``w_to_i`` maps label tokens to ids.
        max_seq_length: every feature list is padded to exactly this length.
        max_predictions_per_seq: unused here; kept for interface compatibility.
        outfilename: path of the pickle file to write.

    Nothing is written when ``instances`` is empty.
    """
    # Vocabulary id of "[PAD]"; used to pad input_ids (it was previously
    # computed but ignored in favour of a hard-coded 0).
    pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

    features = []
    for instance in instances:
        input_ids = list(tokenizer.convert_tokens_to_ids(instance.tokens))
        num_real = len(input_ids)
        assert num_real <= max_seq_length
        num_pad = max_seq_length - num_real

        input_ids.extend([pad_id] * num_pad)
        # Attention mask: 1 for real tokens, 0 for padding; pad segments are 0.
        input_mask = [1] * num_real + [0] * num_pad
        segment_ids = list(instance.segment_ids) + [0] * num_pad

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        # Only position 0 carries the next-sentence label; the rest stay 0.
        next_sentence_label = [0] * max_seq_length
        next_sentence_label[0] = 1 if instance.is_random_next else 0

        # Convert label tokens to vocabulary ids; 0 marks "not masked".
        labels_tmp = instance.masked_lm_labels
        masked_lm_labels = [
            vocab_words.w_to_i[labels_tmp[m]] if labels_tmp[m] != 0 else 0
            for m in range(max_seq_length)
        ]

        features.append(
            TrainingInstance(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                masked_lm_ids_maxseq=masked_lm_labels,
                next_sentence_labels=next_sentence_label,
            )
        )

    if features:
        with open(outfilename, 'wb') as output:
            pickle.dump(features, output, pickle.HIGHEST_PROTOCOL)

## 3. create_masked_lm_predictions

In [None]:
# Record of one masked position: ``index`` into the token list and ``label``,
# the original token that the model must predict there.
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", "index label")

In [None]:
def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, rng, vocab_words, max_seq_length):
    """Apply BERT-style masking to ``tokens``.

    Positions other than ``[CLS]``/``[SEP]`` are visited in random order and
    the first ``min(max_predictions_per_seq,
    max(1, round(len(tokens) * masked_lm_prob)))`` of them are masked. Each
    masked position becomes ``[MASK]`` with probability 0.8, keeps its token
    with probability 0.1, or becomes a random vocabulary entry whose text does
    not contain ``"unused"`` with probability 0.1.

    Returns:
        ``(output_tokens, masked_lm_labels)``: ``output_tokens`` is a new list;
        ``masked_lm_labels`` has length ``max_seq_length`` and holds the
        original token at every masked position and 0 everywhere else.
    """
    special_tokens = ("[CLS]", "[SEP]")
    candidates = [pos for pos, tok in enumerate(tokens) if tok not in special_tokens]
    rng.shuffle(candidates)

    budget = min(max_predictions_per_seq,
                 max(1, int(round(len(tokens) * masked_lm_prob))))

    def draw_random_word():
        # Resample until the word is not an "[unusedN]"-style placeholder.
        last_id = len(vocab_words.i_to_w) - 1
        while True:
            word = vocab_words.i_to_w[rng.randint(0, last_id)]
            if "unused" not in word:
                return word

    output_tokens = list(tokens)
    picked = []
    remaining = iter(candidates)
    while len(picked) < budget:
        pos = next(remaining, None)
        if pos is None:
            break
        if rng.random() < 0.8:
            replacement = "[MASK]"
        elif rng.random() < 0.5:
            replacement = tokens[pos]
        else:
            replacement = draw_random_word()
        output_tokens[pos] = replacement
        picked.append(pos)

    masked_lm_labels = [0] * max_seq_length
    for pos in sorted(picked):
        masked_lm_labels[pos] = tokens[pos]

    return (output_tokens, masked_lm_labels)

In [None]:
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Trim ``tokens_a``/``tokens_b`` in place until their combined length fits.

    Each step removes one token from the longer list (``tokens_b`` on a tie).
    The token is taken from the front or the back with equal probability.
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        longer = tokens_b if len(tokens_b) >= len(tokens_a) else tokens_a
        assert len(longer) > 0

        # Randomizing the trimmed end avoids always biasing toward one side.
        if rng.random() >= 0.5:
            longer.pop()
        else:
            del longer[0]

## 2. create_instances_from_document
- Build masked-LM training instances from each document

In [None]:
class TrainingInstance_tmp(object):
    """Token-level training example, before conversion to vocabulary ids.

    As built by ``create_instances_from_document``:

    - ``tokens`` is the masked ``[CLS] A [SEP] B [SEP]`` token list.
    - ``segment_ids`` gives the per-token segment (0 for A, 1 for B).
    - ``masked_lm_labels`` holds the original token at masked positions and 0
      elsewhere.
    - ``is_random_next`` is the next-sentence flag.
    """

    def __init__(self, tokens, segment_ids, masked_lm_labels, is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.is_random_next = is_random_next
        self.masked_lm_labels = masked_lm_labels

In [None]:
def _wrap_pair_with_special_tokens(tokens_a, tokens_b):
    """Return ``([CLS] A [SEP] B [SEP], segment_ids)``; A and its [CLS]/[SEP] get segment 0, B and the final [SEP] get 1."""
    tokens = ["[CLS]"] + list(tokens_a) + ["[SEP]"] + list(tokens_b) + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, rng, vocab_words, tokenizer):
    """Create ``TrainingInstance_tmp`` examples from ``all_documents[document_index]``.

    Token segments are accumulated into a chunk until the chunk reaches the
    target length or the document ends. Each chunk is split at a random point
    into segments A and B, truncated to fit ``max_seq_length``, wrapped with
    special tokens and masked. ``is_random_next`` is always False: no NSP
    negatives are generated.

    Args:
        all_documents: list of documents; each document is a list of token
            segments (lists of token strings).
        document_index: which document to process.
        max_seq_length: final sequence length, including 3 special tokens.
        short_seq_prob: probability of choosing a shorter random target length.
        masked_lm_prob, max_predictions_per_seq, vocab_words: forwarded to
            ``create_masked_lm_predictions``.
        rng: ``random.Random`` used for every random choice.
        tokenizer: unused; kept for interface compatibility.

    Returns:
        list of ``TrainingInstance_tmp``.
    """
    # One document holds one patient's records.
    document = all_documents[document_index]

    # Reserve room for [CLS], [SEP], [SEP].
    max_num_tokens = max_seq_length - 3

    target_seq_length = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)

    # Long documents start at a random segment.
    # NOTE(review): despite its name, max_sampling does not cap how many
    # segments are consumed -- the loop still runs to the end of the document;
    # it only bounds how late the random start can be. Confirm this is intended.
    max_sampling = 40
    if len(document) <= max_sampling:
        start = 0
    else:
        start = rng.randint(0, len(document) - 1 - max_sampling)

    instances = []
    current_chunk = []
    current_length = 0
    last_index = len(document) - 1

    for i in range(start, len(document)):
        segment = document[i]
        current_chunk.extend(segment)
        current_length += len(segment)

        if i != last_index and current_length < target_seq_length:
            continue

        if current_chunk:
            # Random split point between A and B. A one-token chunk cannot be
            # split (randint(1, 0) raises ValueError), so it all goes to A.
            a_end = 1
            if len(current_chunk) >= 2:
                a_end = rng.randint(1, len(current_chunk) - 1)
            tokens_a = current_chunk[:a_end]
            tokens_b = current_chunk[a_end:]
            truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

            tokens, segment_ids = _wrap_pair_with_special_tokens(tokens_a, tokens_b)
            tokens, masked_lm_labels = create_masked_lm_predictions(
                tokens, masked_lm_prob, max_predictions_per_seq, rng, vocab_words, max_seq_length)

            instances.append(TrainingInstance_tmp(
                tokens=tokens,
                segment_ids=segment_ids,
                is_random_next=False,
                masked_lm_labels=masked_lm_labels))

        current_chunk = []
        current_length = 0

    return instances

## 1. Create Training instances

In [None]:
def _read_documents(targetfile, tokenizer, do_lower_case):
    """Read ``targetfile`` into documents of tokenized lines; blank lines separate documents."""
    # NOTE(review): opened with the platform default encoding, as before,
    # although convert_to_unicode assumes UTF-8 -- confirm the data encoding.
    with open(targetfile, "r") as f:
        lines = f.readlines()

    documents = []
    stack = []
    for raw_line in lines:
        line = convert_to_unicode(raw_line.strip())
        if line:
            if do_lower_case:
                line = line.lower()
            stack.append(tokenizer.tokenize(line))
        elif stack:
            documents.append(stack)
            stack = []

    if stack:
        documents.append(stack)
    return documents


def _arrange_documents(documents, cutlength):
    """Flatten each document to one token stream and cut it into segments of at most ``cutlength`` tokens (segments under 3 tokens are dropped)."""
    arranged = []
    for doc in documents:
        doc_tokens = [tok for line in doc for tok in line]
        segments = []
        for begin in range(0, len(doc_tokens), cutlength):
            segment = doc_tokens[begin:begin + cutlength]
            if len(segment) >= 3:
                segments.append(segment)
        arranged.append(segments)
    return arranged


def create_training_instances(input_files, tokenizer, vocab_words, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob, 
                              max_predictions_per_seq, rng, data_species, do_lower_case):
    """Create shuffled ``TrainingInstance_tmp`` examples from raw text files.

    Each input file is treated as one patient: all of its blank-line-separated
    documents are merged into one. The merged token stream is cut into segments
    of at most ``max_seq_length - 3`` tokens (room for [CLS], [SEP], [SEP]).
    ``create_instances_from_document`` is then run over every file,
    ``dupe_factor`` times, each time with fresh random masking.

    Args:
        input_files: paths of text files, one patient per file.
        tokenizer: provides ``tokenize``.
        vocab_words: ``Vocab_words`` passed through for masking.
        max_seq_length: final sequence length; assumed to be > 3.
        dupe_factor: how many masked copies of the data to generate.
        short_seq_prob, masked_lm_prob, max_predictions_per_seq: passed through.
        rng: ``random.Random`` for all random choices.
        data_species: unused; kept for interface compatibility.
        do_lower_case: lower-case each line before tokenizing.

    Returns:
        list of instances, shuffled with ``rng``.
    """
    documents = []

    print("reading...")
    for input_file in input_files:
        # One patient's records are concatenated and treated as one document.
        pt_doc = _read_documents(input_file, tokenizer, do_lower_case)
        documents.append([line for doc in pt_doc for line in doc])
    print("done read")

    # Segment length now follows max_seq_length (it was hard-coded to 512-3).
    # Segmentation uses no randomness, so compute it once rather than once per
    # duplication round.
    cutlength = max_seq_length - 3
    arranged_documents = _arrange_documents(documents, cutlength)

    instances = []
    for dup in range(dupe_factor):
        print("dupe time: ", dup)

        for document_index in range(len(arranged_documents)):
            if (document_index + 1) % 100 == 0:
                print("document_index: ", (document_index + 1), "/", len(arranged_documents))

            instances.extend(
                create_instances_from_document(
                    arranged_documents, document_index, max_seq_length, short_seq_prob,
                    masked_lm_prob, max_predictions_per_seq, rng, vocab_words, tokenizer))

    print("len(instances): ", len(instances))
    rng.shuffle(instances)

    return instances


    

def convert_to_unicode(text):
    """Return ``text`` as a Unicode string, decoding byte strings as UTF-8.

    Undecodable bytes are dropped (``errors="ignore"``). Raises ``ValueError``
    for non-string input, or when running on neither Python 2 nor Python 3.
    """
    if six.PY3:
        text_type, binary_type = str, bytes
    elif six.PY2:
        text_type, binary_type = unicode, str  # noqa: F821 -- Python 2 only
    else:
        raise ValueError("Not running on Python2 or Python 3?")

    if isinstance(text, binary_type):
        return text.decode("utf-8", "ignore")
    if isinstance(text, text_type):
        return text
    raise ValueError("Unsupported string type: %s" % (type(text)))
        
        
class Vocab_words(object):
    """Two-way vocabulary lookup loaded from a one-token-per-line file.

    Attributes:
        i_to_w: dict mapping a 0-based line number to its (stripped) token.
        w_to_i: dict mapping a token to its line number; if a token appears on
            several lines, the last occurrence wins.
    """

    def __init__(self, vocab_file):
        self.i_to_w = {}
        self.w_to_i = {}
        self.getvocab(vocab_file)

    def getvocab(self, vocab_file):
        """Populate ``i_to_w`` and ``w_to_i`` from ``vocab_file``."""
        import io  # io.open accepts ``encoding`` on both Python 2 and 3

        # Read explicitly as UTF-8 (the encoding convert_to_unicode assumes)
        # and close the handle. Previously the file was never closed and was
        # decoded with the platform default encoding.
        with io.open(vocab_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        for idx, line in enumerate(lines):
            term = convert_to_unicode(line.strip())
            self.i_to_w[idx] = term
            self.w_to_i[term] = idx

# main()

In [None]:
def main(
    # required
    input_files, 
    vocab_file, 
    outdir,
    
    do_lower_case=True,  # True lower-cases input (uncased vocab); pass False for cased vocabularies
    max_seq_length=512, 
    max_predictions_per_seq = 20, 
    random_seed=12345, 
    dupe_factor = 1, 
    masked_lm_prob = 0.15,
    short_seq_prob = 0.1,
    data_species = "",
    patients_per_cache=100,
    ):
    """Build masked-LM pretraining caches from ``input_files`` into ``outdir``.

    Input files are shuffled and processed in batches of ``patients_per_cache``.
    Each batch becomes one pickle named after the batch's nominal end index,
    zero-padded to 8 digits (e.g. ``00000100.cache``).

    Args:
        input_files: list of patient text files. The caller's list is not
            reordered.
        vocab_file: path to the vocabulary file.
        outdir: directory for the ``.cache`` files; must already exist.
        do_lower_case, max_seq_length, max_predictions_per_seq, dupe_factor,
            masked_lm_prob, short_seq_prob, data_species: forwarded to
            instance creation.
        random_seed: currently ignored (see the note in the body).
        patients_per_cache: input files per cache file (previously a
            hard-coded 100).
    """
    print("vocab_file: ", vocab_file)
    tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case, max_len=max_seq_length)
    vocab_words = Vocab_words(vocab_file)

    print("data_species: ", data_species)
    print("max_predictions_per_seq: ", max_predictions_per_seq)

    # NOTE(review): random_seed is ignored. Each run draws a fresh seed in
    # [1, 1000] (printed below), so output is not reproducible from the
    # arguments alone. Confirm whether that is intended.
    randseed = random.randint(1, 1000)
    print("randseed: ", randseed)
    rng = random.Random(randseed)

    print("len(input_files): ", len(input_files))
    # Shuffle a copy so the caller's list is not reordered as a side effect.
    input_files = list(input_files)
    rng.shuffle(input_files)

    # jump: the number of patients in a cache file
    jump = patients_per_cache
    for i in range(0, len(input_files), jump):
        sampled_inputfiles = input_files[i:i + jump]
        print("len(sampled_inputfiles): ", len(sampled_inputfiles))
        print(i, "~", i + jump, "/", len(input_files))

        instances = create_training_instances(
            sampled_inputfiles, tokenizer, vocab_words, max_seq_length, dupe_factor,
            short_seq_prob, masked_lm_prob, max_predictions_per_seq, rng, data_species, do_lower_case)

        idx_range = i + jump
        filename = str(idx_range).zfill(8) + ".cache"
        print("filename: ", filename)

        write_instance_to_example_files(
            instances=instances,
            tokenizer=tokenizer,
            vocab_words=vocab_words,
            max_seq_length=max_seq_length,
            max_predictions_per_seq=max_predictions_per_seq,
            outfilename=os.path.join(outdir, filename))

# Generate data

In [None]:
import glob
import os

print("Generating Features...")

upper_path = ".."
target_paths = ["../00_data4pretrain/SNUH_visit_2011to2020"]
vocab_path =  "../../../otherberts/bertbase_cased"

# Collect patient files laid out as <upper_path>/<target>/<group>/<patient>.txt
files = []
for target_path in target_paths:
    print("target_paths[t]: ", target_path)

    group_dirs = glob.glob(upper_path + "/" + target_path + "/*")
    print("len(groups): ", len(group_dirs))

    patient_files = [
        txt_path
        for group_dir in group_dirs
        for txt_path in glob.glob(group_dir + "/*.txt")
    ]
    print("len(patients): ", len(patient_files))

    files.extend(patient_files)

out_directory = "./cache/"
if not os.path.exists(out_directory):
    os.makedirs(out_directory)

main(
    # required
    input_files=files,
    vocab_file=vocab_path + '/vocab.txt',
    outdir=out_directory,

    # optional
    do_lower_case=False,  # the bertbase_cased vocabulary is cased
    max_seq_length=512,
    max_predictions_per_seq=76,
    dupe_factor=10,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
)
print("finished")