# Make cache

In [None]:
# pip install sentencepiece
import sentencepiece as spm
import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

import pandas as pd

from tokenization import BertTokenizer
import tokenization as tokenization

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance_ext(object):
    """Record holding the per-token arrays of one cached example.

    Each constructor argument is stored, unmodified, on an attribute of the
    same name.  Instances are created by ``write_instance_to_example_files``
    and pickled, so the attribute names are part of the cache format.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids,
                 start_pos, end_pos, doc_ids, sent_ids):
        # Assigned in the same order as the parameters so that ``vars()``
        # (and therefore the pickled state) keeps a stable field order.
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids, input_mask, segment_ids)
        self.label_ids, self.start_pos, self.end_pos = (
            label_ids, start_pos, end_pos)
        self.doc_ids, self.sent_ids = doc_ids, sent_ids

In [None]:
def write_instance_to_example_files(instances, 
                                    tokenizer, 
                                    vocab_words,
                                    max_seq_length,
                                    max_predictions_per_seq, 
                                    outfilename):
    """Pad every instance to ``max_seq_length`` and pickle the resulting list.

    Args:
        instances: iterable of objects exposing ``tokens``, ``segment_ids``,
            ``label_ids``, ``start_pos``, ``end_pos``, ``doc_ids`` and
            ``sent_ids``.
        tokenizer: object providing ``convert_tokens_to_ids(list_of_tokens)``.
        vocab_words: accepted for signature compatibility; not read here.
        max_seq_length: exact length every array is padded to.
        max_predictions_per_seq: accepted for signature compatibility; not
            read here.
        outfilename: path of the pickle file to create.  Nothing is written
            when ``instances`` is empty.

    Raises:
        AssertionError: if an instance is longer than ``max_seq_length`` or
            if any of its per-token arrays does not end up exactly
            ``max_seq_length`` long (i.e. the arrays were not aligned).
    """
    pad_id = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

    features = []
    for instance in instances:
        input_ids = list(tokenizer.convert_tokens_to_ids(instance.tokens))
        input_mask = [1] * len(input_ids)  # attention mask: 1 = real token
        segment_ids = list(instance.segment_ids)
        label_ids = list(instance.label_ids)
        start_pos = list(instance.start_pos)
        end_pos = list(instance.end_pos)
        doc_ids = list(instance.doc_ids)
        sent_ids = list(instance.sent_ids)

        # Padding uses the opposite of the last real segment id.
        seg_pad = 1 if segment_ids[-1] == 0 else 0

        assert len(input_ids) <= max_seq_length
        n_pad = max_seq_length - len(input_ids)

        # Pad all parallel arrays by the same amount, then verify each one.
        # Checking every array (not only some of them) catches instances
        # whose label/doc/sentence arrays are misaligned with the tokens.
        padded = (
            (input_ids, pad_id),
            (input_mask, 0),
            (segment_ids, seg_pad),
            (label_ids, 0),
            (start_pos, 0),
            (end_pos, 0),
            (doc_ids, -1),
            (sent_ids, -1),
        )
        for seq, filler in padded:
            seq.extend([filler] * n_pad)
            assert len(seq) == max_seq_length

        features.append(
            TrainingInstance_ext(
                input_ids = input_ids,
                input_mask = input_mask, 
                segment_ids = segment_ids,
                label_ids = label_ids,
                start_pos = start_pos,
                end_pos = end_pos, 
                doc_ids = doc_ids,
                sent_ids = sent_ids
            )
        )

    if features:
        with open(outfilename, 'wb') as output:
            pickle.dump(features, output, pickle.HIGHEST_PROTOCOL)

## 2. create_instances_from_document

In [None]:
class TrainingInstance_ext_tmp(object):
    """Unpadded example that still carries string tokens.

    Each constructor argument is stored, unmodified, on an attribute of the
    same name.  ``create_instances_from_document`` produces these objects and
    ``write_instance_to_example_files`` reads their attributes.
    """

    def __init__(self, tokens, segment_ids, label_ids, start_pos, end_pos,
                 doc_ids, sent_ids):
        # Assignment order follows the parameter order.
        self.tokens, self.segment_ids, self.label_ids = (
            tokens, segment_ids, label_ids)
        self.start_pos, self.end_pos = start_pos, end_pos
        self.doc_ids, self.sent_ids = doc_ids, sent_ids

In [None]:
def truncate_seq_sides(tokens_a, doc_id_a, sent_id_a, doctype_a, section_a, label_a, 
                 tokens_b, doc_id_b, sent_id_b, doctype_b, section_b, label_b,
                 max_num_tokens, rng):
    """Trim the outer ends of a left/right context pair, in place.

    While ``len(tokens_a) + len(tokens_b)`` exceeds ``max_num_tokens`` one
    element is removed per step: from the FRONT of the ``*_a`` lists when
    side a is strictly longer, otherwise from the BACK of the ``*_b`` lists.
    The five lists that accompany each token list are trimmed in lockstep.

    Trimming also stops as soon as both sides are empty, so a budget that
    cannot be met (e.g. a negative ``max_num_tokens``) yields empty lists
    instead of looping forever.

    Args:
        tokens_a .. label_a: parallel lists for the left side (mutated).
        tokens_b .. label_b: parallel lists for the right side (mutated).
        max_num_tokens: maximum combined number of tokens to keep.
        rng: accepted for signature compatibility; not read here.

    Returns:
        A 12-tuple with the same list objects, in parameter order.
    """
    left = (tokens_a, doc_id_a, sent_id_a, doctype_a, section_a, label_a)
    right = (tokens_b, doc_id_b, sent_id_b, doctype_b, section_b, label_b)

    while (tokens_a or tokens_b) and len(tokens_a) + len(tokens_b) > max_num_tokens:
        if len(tokens_a) > len(tokens_b):
            # Left side is longer (this also covers an empty right side):
            # drop the token farthest from the centre.
            for seq in left:
                del seq[0]
        else:
            # Right side is at least as long and therefore non-empty.
            for seq in right:
                seq.pop()

    return left + right
        

In [None]:
def read_documents(targetfile, tokenizer, do_lower_case):
    """Parse a tab-separated sentence file and tokenize each sentence.

    Every line must have at least six tab-separated fields, in this order:
    ``doc_id``, ``sent_id``, ``doctype``, ``section``, ``label``, ``content``.
    Fields after the sixth are ignored.

    Args:
        targetfile: path of the text file to read.
        tokenizer: object providing ``tokenize(text)``.
        do_lower_case: when truthy, ``content`` is lower-cased before
            tokenization.

    Returns:
        A 6-tuple of parallel lists
        ``(doc_ids, sent_ids, doctypes, sections, labels, tokens_lines)``.
        ``doc_ids``, ``sent_ids`` and ``labels`` hold ints; an empty
        ``doc_id`` field becomes ``0``.

    Raises:
        IndexError: if a line has fewer than six fields.
        ValueError: if ``doc_id``, ``sent_id`` or ``label`` is not an integer.
    """
    doc_ids = []
    sent_ids = []
    doctypes = []
    sections = []
    labels = []
    tokens_lines = []

    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(targetfile, "r") as f:
        for raw_line in f:
            fields = raw_line.strip("\n").split("\t")

            doc_id_text = fields[0]
            sent_id = int(fields[1])
            doctype = fields[2]
            section = fields[3]
            label = int(fields[4])
            content = fields[5]

            if do_lower_case:
                content = content.lower()

            tokens = tokenizer.tokenize(content)

            # int() already discards leading zeros ("007" -> 7); only the
            # empty string needs a fallback.
            doc_id = int(doc_id_text) if doc_id_text else 0

            # collect
            doc_ids.append(doc_id)
            sent_ids.append(sent_id)
            doctypes.append(doctype)
            sections.append(section)
            labels.append(label)
            tokens_lines.append(tokens)

    return doc_ids, sent_ids, doctypes, sections, labels, tokens_lines


In [None]:
def _find_label_span(labels):
    """Return ``(first, last + 1)`` indices of the entries equal to 1, or ``(0, 0)`` if none."""
    first = 0
    for idx, value in enumerate(labels):
        if value == 1:
            first = idx
            break
    end_exclusive = 0
    for idx in range(len(labels) - 1, -1, -1):
        if labels[idx] == 1:
            end_exclusive = idx + 1  # exclusive upper bound
            break
    return first, end_exclusive


def create_instances_from_document(input_file, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, rng, do_lower_case, vocab_words, mecab_sp_tokenizer, mode):
    """Build one ``TrainingInstance_ext_tmp`` per document around its label-1 span.

    The file is parsed with ``read_documents``.  Each line's tokens get a
    trailing ``[SEP]`` and consecutive lines with the same doc id form one
    document.  For a document, the tokens from the first to the last
    label-1 token (called the "assessment" in the original comments) are
    kept, and the rest of the ``max_seq_length - 1`` budget (one slot is
    reserved for ``[CLS]``) is filled with context on both sides via
    ``truncate_seq_sides``.

    Args:
        input_file: path handed to ``read_documents``.
        max_seq_length: maximum number of tokens per instance, ``[CLS]``
            included.
        short_seq_prob, masked_lm_prob, max_predictions_per_seq, vocab_words:
            accepted for signature compatibility; not read here.
        rng: ``random.Random``-like object used for every random draw; the
            module-level ``random`` is used when it is ``None``.
        do_lower_case: forwarded to ``read_documents``.
        mecab_sp_tokenizer: tokenizer forwarded to ``read_documents``.
        mode: ``"train"`` picks a random starting document and returns after
            the first instance; any other value processes every document.

    Returns:
        List of ``TrainingInstance_ext_tmp``; empty when the file has no lines.

    Raises:
        IndexError: if a processed document contains no token labelled 1.
    """
    # Honour the caller-supplied generator so runs can be reproduced from
    # its seed; fall back to the global one for callers that pass None.
    rand = rng if rng is not None else random

    instances = []

    ########### tokenize all documents ###########
    doc_ids, sent_ids, doctypes, sections, labels, tokens_lines = read_documents(input_file, mecab_sp_tokenizer, do_lower_case)

    # An empty file has no document to sample from.
    if not tokens_lines:
        return instances

    ########### concatenate all per-line information ###########
    doc_id_concat = []
    sent_id_concat = []
    doctype_concat = []
    section_concat = []
    label_concat = []
    token_concat = []

    # cut_interval[k] .. cut_interval[k+1] delimits the k-th document.
    cut_interval = [0]

    for t, line_tokens in enumerate(tokens_lines):
        if t != 0 and doc_ids[t] != doc_ids[t - 1]:
            cut_interval.append(len(token_concat))

        width = len(line_tokens) + 1  # the line's tokens plus its [SEP]

        # extend() keeps this linear; repeated ``a = a + b`` is quadratic.
        token_concat.extend(line_tokens)
        token_concat.append("[SEP]")
        doc_id_concat.extend([doc_ids[t]] * width)
        sent_id_concat.extend([sent_ids[t]] * width)
        doctype_concat.extend([doctypes[t]] * width)
        section_concat.extend([sections[t]] * width)
        label_concat.extend([labels[t]] * width)

    cut_interval.append(len(token_concat))

    assert len(token_concat)==len(doc_id_concat)
    assert len(token_concat)==len(sent_id_concat)
    assert len(token_concat)==len(doctype_concat)
    assert len(token_concat)==len(section_concat)
    assert len(token_concat)==len(label_concat)

    if mode=="train":
        # cut_interval always has at least two entries here.
        point = rand.randint(0, len(cut_interval) - 2)
    else:
        point = 0

    ########### sampling ###########
    for t in range(point, len(cut_interval)-1):
        start = cut_interval[t]
        end = cut_interval[t+1]

        # Slice each column once instead of re-slicing on every access.
        seg_tokens = token_concat[start:end]
        seg_doc_ids = doc_id_concat[start:end]
        seg_sent_ids = sent_id_concat[start:end]
        seg_doctypes = doctype_concat[start:end]
        seg_sections = section_concat[start:end]
        seg_labels = label_concat[start:end]

        # Locate the label-1 span; context is later cut around it.
        label_start_idx, label_end_idx = _find_label_span(seg_labels)

        # If the assessment alone cannot fit next to [CLS], force-truncate
        # it to a random 100~200 tokens (capped so it still fits).
        # The draw happens unconditionally, as before.
        assessmentlen = label_end_idx - label_start_idx
        span = rand.randint(100, 200)
        if assessmentlen >= max_seq_length:
            span = min(span, max_seq_length - 1)
            label_end_idx = min(label_end_idx, label_start_idx + span)
            # NOTE(review): the cut-off label-1 tokens move into the right
            # context below, so the final end position may lie beyond this
            # shortened span - confirm that this is intended.

        # assessment, forced to end with [SEP]
        tokens_asmt = seg_tokens[label_start_idx:label_end_idx]
        if not tokens_asmt:
            raise IndexError(
                "no token labelled 1 in document %d of %r" % (t, input_file))
        if tokens_asmt[-1]!="[SEP]":
            tokens_asmt[-1]="[SEP]"

        # left context
        tokens_a = seg_tokens[:label_start_idx]
        doc_id_a = seg_doc_ids[:label_start_idx]
        sent_id_a = seg_sent_ids[:label_start_idx]
        doctype_a = seg_doctypes[:label_start_idx]
        section_a = seg_sections[:label_start_idx]
        label_a = seg_labels[:label_start_idx]

        # right context
        tokens_b = seg_tokens[label_end_idx:]
        doc_id_b = seg_doc_ids[label_end_idx:]
        sent_id_b = seg_sent_ids[label_end_idx:]
        doctype_b = seg_doctypes[label_end_idx:]
        section_b = seg_sections[label_end_idx:]
        label_b = seg_labels[label_end_idx:]

        # budget left for context once the assessment and [CLS] are placed
        max_trunlen = max_seq_length - (label_end_idx - label_start_idx) -1 # [CLS]

        tokens_a, doc_id_a, sent_id_a, doctype_a, section_a, label_a,\
        tokens_b, doc_id_b, sent_id_b, doctype_b, section_b, label_b = truncate_seq_sides(
                 tokens_a, doc_id_a, sent_id_a, doctype_a, section_a, label_a, 
                 tokens_b, doc_id_b, sent_id_b, doctype_b, section_b, label_b,
                 max_trunlen, rng)

        # After truncation the left context must not start with a stray
        # [SEP] left over from a partially removed sentence.
        if len(tokens_a)>0:
            if tokens_a[0]=="[SEP]":
                del tokens_a[0]
                del doc_id_a[0]
                del sent_id_a[0]
                del label_a[0]

        # The right context must end with [SEP]; the last token is
        # overwritten (not appended) so the length budget is kept.
        if len(tokens_b)>0:
            if tokens_b[-1]!="[SEP]":
                tokens_b[-1] = "[SEP]"

        tokens = ["[CLS]"] + tokens_a + tokens_asmt + tokens_b
        inst_doc_ids = [-1] + doc_id_a + seg_doc_ids[label_start_idx:label_end_idx] + doc_id_b
        inst_sent_ids = [-1] + sent_id_a + seg_sent_ids[label_start_idx:label_end_idx] + sent_id_b
        label_ids = [0] + label_a + seg_labels[label_start_idx:label_end_idx] + label_b

        # One-hot position arrays for the first and last label-1 token.
        label_start_idx, label_end_idx = _find_label_span(label_ids)
        start_pos = [0]*len(label_ids)
        start_pos[label_start_idx] = 1
        end_pos = [0]*len(label_ids)
        end_pos[label_end_idx-1] = 1 # the span end is exclusive, so step back by one

        # segment_ids alternate between 0 and 1 whenever the sentence id changes
        # NOTE(review): [CLS] carries sentence id -1, so the first real
        # sentence normally flips to segment 1 - confirm this is intended.
        seg_id = 0
        segment_ids = [0] # [CLS]
        last_sent_id = inst_sent_ids[0]
        for s in range(1, len(inst_sent_ids)): # loop skipping [CLS]
            if last_sent_id!=inst_sent_ids[s]:
                seg_id = 1 if seg_id==0 else 0
            segment_ids.append(seg_id)
            last_sent_id = inst_sent_ids[s]

        assert len(tokens)==len(segment_ids)
        assert len(segment_ids)==len(label_ids)
        assert len(label_ids)==len(inst_doc_ids)
        assert len(inst_doc_ids)==len(inst_sent_ids)

        instance = TrainingInstance_ext_tmp(
                tokens = tokens,
                segment_ids = segment_ids,
                label_ids = label_ids, 
                start_pos = start_pos,
                end_pos = end_pos,
                doc_ids = inst_doc_ids, 
                sent_ids = inst_sent_ids, 
            )
        instances.append(instance)

        # In train mode processing every document takes too long, so only
        # one document is sampled per call; the caller repeats the call
        # dupe_factor times.
        if mode=="train":
            return instances

    return instances

## 1. Create Training instances

In [None]:
def create_training_instances(input_file, tokenizer, vocab_words, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng, do_lower_case, mode):
    """Run ``create_instances_from_document`` ``dupe_factor`` times and merge the results.

    All arguments are forwarded unchanged; ``tokenizer`` is passed as the
    callee's ``mecab_sp_tokenizer`` argument.  The returned list contains the
    instances of every pass, in call order.
    """
    passes = [
        create_instances_from_document(
            input_file,
            max_seq_length,
            short_seq_prob,
            masked_lm_prob,
            max_predictions_per_seq,
            rng,
            do_lower_case,
            vocab_words,
            tokenizer,
            mode,
        )
        for _ in range(dupe_factor)
    ]
    # Flatten the per-pass lists while keeping their order.
    return [instance for one_pass in passes for instance in one_pass]

def convert_to_unicode(text):
    """Return `text` as a unicode string, decoding byte strings as utf-8.

    Undecodable bytes are dropped ("ignore").  A ``ValueError`` is raised for
    any other input type, or when the interpreter is neither Python 2 nor
    Python 3 according to ``six``.
    """
    if six.PY3:
        text_type, byte_type = str, bytes
    elif six.PY2:
        # ``unicode`` only exists on Python 2; this branch never runs on 3.
        text_type, byte_type = unicode, str  # noqa: F821
    else:
        raise ValueError("Not running on Python2 or Python 3?")

    if isinstance(text, text_type):
        return text
    if isinstance(text, byte_type):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

# main()

In [None]:
def main(
    # required
    input_files, 
    vocab_file, 
    outdir,
    mode,
    
    # optional
    do_lower_case=True, 
    max_seq_length=512, 
    max_predictions_per_seq = 20, 
    random_seed=12345, 
    dupe_factor = 1,
    masked_lm_prob = 0.15,
    short_seq_prob = 0.1,
    ):
    """Build and cache features for every input file.

    For each path in ``input_files`` the instances are created with
    ``create_training_instances`` and written by
    ``write_instance_to_example_files`` to ``<outdir>/<stem>.cache``, where
    ``<stem>`` is the file's base name up to its first dot.

    Args:
        input_files: list of input text file paths.
        vocab_file: vocabulary file handed to ``BertTokenizer``.
        outdir: existing directory receiving the ``.cache`` files.
        mode: forwarded to ``create_training_instances``.
        do_lower_case: forwarded to the tokenizer and the instance builder.
        max_seq_length: padded sequence length.
        max_predictions_per_seq, masked_lm_prob, short_seq_prob: forwarded
            unchanged.
        random_seed: NOTE(review): never read - the generator is seeded with
            a fresh random value that is printed instead; confirm whether
            this parameter should be honoured.
        dupe_factor: number of sampling passes per file.
    """
    # bertbase
    print("vocab_file: ", vocab_file)
    tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case, max_len=max_seq_length)
    vocab_words = list(tokenizer.vocab.keys())
    print("len(vocab_words): ", len(vocab_words))

    print("max_predictions_per_seq: ", max_predictions_per_seq)

    randseed = random.randint(1, 1000)
    print("randseed: ", randseed)
    rng = random.Random(randseed)

    print("len(input_files): ", len(input_files))

    for i, input_file in enumerate(input_files):
        if i%100==0:
            print(str(i)+"/"+str(len(input_files)))
        instances = create_training_instances(
                        input_file, tokenizer, vocab_words, max_seq_length,
                        dupe_factor, short_seq_prob, masked_lm_prob,
                        max_predictions_per_seq, rng, do_lower_case, mode)

        # basename() understands the platform's separators, unlike
        # split("/"), which mishandles the backslash paths glob returns on
        # Windows.  The stem still ends at the first dot, as before.
        filename = os.path.basename(input_file).split(".")[0]+".cache"

        write_instance_to_example_files(instances=instances, 
                                    tokenizer=tokenizer, 
                                    vocab_words=vocab_words,
                                    max_seq_length=max_seq_length,
                                    max_predictions_per_seq=max_predictions_per_seq, 
                                    outfilename=os.path.join(outdir, filename))

# Generate train and test data

In [None]:
import glob
import os

# Driver: for every data split and every tokenizer vocabulary, build the
# feature caches under ./cache/<model>/<mode>/.
print("Generating Features...")

# One entry per data split; data_dirs, modes and dupe_factors are parallel.
data_dirs = [
    "./data/05_sampled/train",
    "./data/05_sampled/test"
]
modes = ["train", "test"]
dupe_factors = [4, 1]

# One entry per model; output_paths, vocab_paths and lowercase are parallel.
output_paths = [
    "bertbase_cased",
    "mbert_cased",
    "biobert"
]

vocab_paths = [
    "../otherberts/bertbase_cased",
    "../otherberts/mbert_cased",
    "../otherberts/bioBERT/biobert_v1.1_pubmed",
]

lowercase = [False, False, False]


assert len(output_paths)==len(vocab_paths)
assert len(vocab_paths)==len(lowercase)
# zip() would silently drop unmatched entries, so check these as well.
assert len(data_dirs)==len(modes)
assert len(modes)==len(dupe_factors)

for data_dir, mode, dupe_factor in zip(data_dirs, modes, dupe_factors):
    # The input list depends only on the split, so glob once per split.
    input_files = sorted(glob.glob(data_dir+"/*.txt"))

    for output_path, vocab_path, do_lower in zip(output_paths, vocab_paths, lowercase):
        print("data_dirs[d]: ", data_dir)

        out_directory = "./cache/"+str(output_path)+"/"+str(mode)
        print("out_directory: ", out_directory)
        # exist_ok avoids the check-then-create race.
        os.makedirs(out_directory, exist_ok=True)

        main(
            input_files = input_files, 
            outdir = out_directory,
            vocab_file = vocab_path+'/vocab.txt', # original note: vocabulary containing both Korean and English
            mode = mode,
            
            # optional
            do_lower_case = do_lower, 
            max_seq_length = 512, 
            dupe_factor = dupe_factor, 
        )
print("Done")