# Build feature caches for the BERT variants

In [None]:
import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

from tokenization import BertTokenizer
import tokenization as tokenization

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance(object):
    """Model-ready features for one sentence pair.

    ``write_instance_to_example_files`` builds these from tokenized
    instances (ids, mask and segment ids padded to a fixed length) and
    pickles a list of them into a ``.cache`` file.
    """

    def __init__(self, input_ids, input_mask, segment_ids, is_entailed):
        # Fields are stored as given; no copying or validation is done here.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.is_entailed = segment_ids, is_entailed

In [None]:
# Convert tokenized instances into padded features and pickle them to a cache file
def write_instance_to_example_files(instances,
                                    tokenizer,
                                    vocab_words,
                                    max_seq_length,
                                    outfilename):
    """Pad tokenized instances to a fixed length and pickle them to a file.

    Each instance must expose ``tokens``, ``segment_ids`` and ``is_entailed``.
    ``create_instances_from_document`` produces instances of this shape.
    Tokens are mapped to ids with ``tokenizer.convert_tokens_to_ids``. The
    ids, attention mask and segment ids are then right-padded with zeros to
    ``max_seq_length``. The resulting ``TrainingInstance`` objects are
    written to ``outfilename`` as a single pickled list.

    Args:
        instances: iterable of tokenized sentence-pair instances.
        tokenizer: object providing ``convert_tokens_to_ids``.
        vocab_words: unused; kept for interface compatibility.
        max_seq_length: fixed length every feature sequence is padded to.
        outfilename: path of the pickle file to write.

    Raises:
        ValueError: if an instance has more than ``max_seq_length`` tokens,
            or its segment ids do not line up one-to-one with its tokens.

    Nothing is written when ``instances`` is empty.
    """
    features = []

    for instance in instances:
        # Copy so the padding below never mutates a list owned by the
        # tokenizer or by the instance.
        input_ids = list(tokenizer.convert_tokens_to_ids(instance.tokens))
        segment_ids = list(instance.segment_ids)
        seq_len = len(input_ids)

        # Explicit checks instead of asserts: asserts vanish under `python -O`.
        if seq_len > max_seq_length:
            raise ValueError(
                "instance has %d tokens, exceeding max_seq_length=%d"
                % (seq_len, max_seq_length))
        if len(segment_ids) != seq_len:
            raise ValueError(
                "instance has %d segment ids for %d tokens"
                % (len(segment_ids), seq_len))

        pad = [0] * (max_seq_length - seq_len)
        input_mask = [1] * seq_len + pad
        input_ids.extend(pad)
        segment_ids.extend(pad)

        # The label is stored as a max_seq_length vector: slot 0 (the [CLS]
        # position) holds 1 for entailed / 0 otherwise; every other slot is 0.
        is_entailed = [0] * max_seq_length
        is_entailed[0] = 1 if instance.is_entailed else 0

        features.append(
            TrainingInstance(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                is_entailed=is_entailed,
            )
        )

    if features:
        # NOTE: pickle files execute code when loaded; load these caches only
        # from trusted locations.
        with open(outfilename, 'wb') as output:
            pickle.dump(features, output, pickle.HIGHEST_PROTOCOL)

In [None]:
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Shorten two token lists in place until their combined length fits.

    While the pair is too long, one token is dropped from whichever list is
    currently longer (``tokens_b`` wins ties). ``rng.random()`` decides
    whether the token comes off the front (< 0.5) or the back.
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        longer = tokens_b if len(tokens_b) >= len(tokens_a) else tokens_a
        assert len(longer) >= 1

        if rng.random() >= 0.5:
            longer.pop()
        else:
            del longer[0]

## 2. create_instances_from_document

In [None]:
class TrainingInstance_tmp(object):
    """Intermediate, unpadded sentence-pair instance.

    Holds the token strings, their segment ids and a boolean entailment
    flag. ``write_instance_to_example_files`` later converts these to ids and
    pads them.
    """

    def __init__(self, tokens, segment_ids, is_entailed):
        self.tokens, self.segment_ids, self.is_entailed = (
            tokens, segment_ids, is_entailed)

In [None]:
def create_instances_from_document(
    samples, max_seq_length, rng, vocab_words, tokenizer, mode):
    """Turn parsed sentence-pair records into ``TrainingInstance_tmp`` objects.

    Each record is a list in the layout produced by ``read_dataset``:
    ``[category, ptnum1, ptnum2, label, tokens_a, tokens_b]``. A record
    counts as entailed when its label string is exactly ``"1"``.

    The token lists are shortened so that the pair fits in
    ``max_seq_length`` tokens together with one ``[CLS]`` and two ``[SEP]``
    markers:

    * ``mode == "train"``: tokens are removed one at a time from the longer
      side, from the front or back as chosen by ``rng``, until the pair fits.
    * ``mode == "test"``: each side is clipped to its first
      ``(max_seq_length - 3) // 2`` tokens.

    Args:
        samples: list of records as described above.
        max_seq_length: total sequence length including special tokens.
        rng: ``random.Random``-like object, used in "train" mode.
        vocab_words: unused; kept for interface compatibility.
        tokenizer: unused; kept for interface compatibility.
        mode: ``"train"`` or ``"test"``.

    Returns:
        list of ``TrainingInstance_tmp``, in the same order as ``samples``.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``
            (otherwise sequences would silently go untruncated).
    """
    if mode not in ("train", "test"):
        raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))

    # Budget left for content after reserving [CLS], [SEP], [SEP].
    max_num_tokens = max_seq_length - 3

    instances = []
    for record in samples:
        is_entailed = record[3] == "1"

        # Work on copies: truncate_seq_pair edits its arguments in place and
        # the caller's samples should not change as a side effect.
        tokens_a = list(record[4])
        tokens_b = list(record[5])

        if mode == "train":
            truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
        else:
            tokens_a = tokens_a[:max_num_tokens // 2]
            tokens_b = tokens_b[:max_num_tokens // 2]

        # Segment 0 covers "[CLS] tokens_a [SEP]"; segment 1 covers "tokens_b [SEP]".
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
        segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

        instances.append(
            TrainingInstance_tmp(
                tokens=tokens,
                segment_ids=segment_ids,
                is_entailed=is_entailed))

    return instances

## 1. Create Training instances

In [None]:
def create_training_instances(samples, tokenizer, vocab_words, max_seq_length, rng, mode):
    """Build sentence-pair instances for ``samples`` and return them shuffled.

    Delegates construction to ``create_instances_from_document``, passing
    ``rng`` through for truncation. The resulting list is then shuffled in
    place with the same ``rng``.
    """
    shuffled = list(
        create_instances_from_document(
            samples, max_seq_length, rng, vocab_words, tokenizer, mode
        )
    )
    rng.shuffle(shuffled)
    return shuffled


def convert_to_unicode(text):
    """Return ``text`` as a unicode string.

    Byte strings are decoded as UTF-8, silently dropping undecodable bytes.
    Strings that are already unicode are returned unchanged. Any other type
    raises ``ValueError``.
    """
    if six.PY3:
        if isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        if isinstance(text, str):
            return text
        raise ValueError("Unsupported string type: %s" % (type(text)))

    if six.PY2:
        # `unicode` only exists on Python 2; this branch never runs on Python 3.
        if isinstance(text, unicode):  # noqa: F821
            return text
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        raise ValueError("Unsupported string type: %s" % (type(text)))

    raise ValueError("Not running on Python2 or Python 3?")

In [None]:
def read_dataset(targetfile, tokenizer, version, startline, endline, do_lower_case):
    """Read and tokenize a slice of a tab-separated sentence-pair file.

    A usable line has at least six tab-separated fields:
    ``category, ptnum1, ptnum2, label, text_a, text_b``. Lines with fewer
    fields are skipped. ``text_a`` and ``text_b`` are passed through
    ``tokenizer.tokenize``, after lower-casing when ``do_lower_case`` is true.

    Args:
        targetfile: path of the input text file.
        tokenizer: object providing ``tokenize(str) -> list``.
        version: unused; kept for interface compatibility.
        startline: start of the line slice (Python slice semantics).
        endline: end of the line slice (Python slice semantics).
        do_lower_case: lower-case both texts before tokenizing.

    Returns:
        list of records ``[category, ptnum1, ptnum2, label, tokens_a, tokens_b]``.
    """
    # NOTE: the whole file is read on every call; callers that process many
    # slices pay this cost repeatedly.
    with open(targetfile, "r") as f:
        lines = f.readlines()[startline:endline]

    total = len(lines)
    stacks = []
    for idx, raw in enumerate(lines):
        if idx % 100 == 0:
            print(idx, "/", total)

        # Drop the line terminator so it cannot leak into the last field.
        fields = raw.rstrip("\r\n").split("\t")
        # Six fields are consumed below; shorter lines would raise IndexError.
        if len(fields) < 6:
            continue

        cate, ptnum1, ptnum2, label, text_a, text_b = fields[:6]
        if do_lower_case:
            text_a = text_a.lower()
            text_b = text_b.lower()

        stacks.append([
            cate,      # data category
            ptnum1,
            ptnum2,
            label,
            tokenizer.tokenize(text_a),
            tokenizer.tokenize(text_b),
        ])
    return stacks

# main()

In [None]:
def main(
    # required
    input_file,
    vocab_file,
    version,
    outdir,
    mode,

    # optional
    do_lower_case=True,
    max_seq_length=512,
    random_seed=12345,
    ):
    """Tokenize ``input_file`` in 1000-line chunks and write one cache per chunk.

    Each chunk's cache file is named after the chunk's first line index,
    zero-padded to eight digits (e.g. ``00001000.cache``), inside ``outdir``.
    Chunks that yield no instances produce no file.

    Args:
        input_file: tab-separated sentence-pair file (see ``read_dataset``).
        vocab_file: vocabulary file for ``BertTokenizer``.
        version: model-variant name, forwarded to ``read_dataset``.
        outdir: existing directory that receives the ``.cache`` files.
        mode: ``"train"`` or ``"test"``; selects the truncation strategy.
        do_lower_case: lower-case text before tokenizing.
        max_seq_length: padded sequence length of every feature.
        random_seed: seed for the RNG that drives truncation and shuffling.
    """
    print("vocab_file: ", vocab_file)
    tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case, max_len=max_seq_length)
    vocab_words = list(tokenizer.vocab.keys())
    print("len(vocab_words): ", len(vocab_words))

    # Seed from the argument so repeated runs produce identical caches.
    print("randseed: ", random_seed)
    rng = random.Random(random_seed)

    # Only the line count is needed here; read_dataset reads each chunk itself.
    with open(input_file, "r") as f:
        num_lines = sum(1 for _ in f)

    print("generating...")
    print("len(lines): ", num_lines)
    jump = 1000
    for s in range(0, num_lines, jump):
        e = min(s + jump, num_lines)
        print(s, "~", e, "/", num_lines)
        samples = read_dataset(input_file, tokenizer, version, startline=s, endline=e, do_lower_case=do_lower_case)

        filename = str(s).zfill(8) + ".cache"

        instances = create_training_instances(samples, tokenizer, vocab_words, max_seq_length, rng, mode)

        write_instance_to_example_files(
            instances=instances,
            tokenizer=tokenizer,
            vocab_words=vocab_words,
            max_seq_length=max_seq_length,
            outfilename=os.path.join(outdir, filename))


# Generate data

In [None]:
import glob
import os
os.makedirs("./cache", exist_ok=True)

print("Generating Features...")

# Per-variant configuration: versions, vocab_paths and do_lower_cases are
# matched by position.
versions = [
    "bertbase_cased",
    "mbert_cased",
    "biobert"
]

datapath = ["./data/train.txt", "./data/test.txt"]

vocab_paths = [
        "../otherberts/bertbase_cased",
        "../otherberts/mbert_cased",
        "../otherberts/bioBERT/biobert_v1.1_pubmed",
]

do_lower_cases = [False, False, False]

# Explicit checks (not asserts) so a mismatched config is never silently
# truncated by zip() below, even under `python -O`.
if len(versions) != len(vocab_paths):
    raise ValueError("versions and vocab_paths must have the same length")
if len(versions) != len(do_lower_cases):
    raise ValueError("versions and do_lower_cases must have the same length")

# folder loop
for version, vocab_path, do_lower_case in zip(versions, vocab_paths, do_lower_cases):
    print("versions: ", version)

    # sample loop
    for path in datapath:
        print("datapath[s]: ", path)
        datapathname = os.path.basename(path).split(".")[0]
        out_directory = os.path.join("./cache", str(version), str(datapathname))
        os.makedirs(out_directory, exist_ok=True)

        # Decide the mode from the file name only, so a parent directory
        # whose name contains "train" cannot flip it.
        mode = "train" if "train" in datapathname else "test"

        main(
            # required
            input_file=path,
            outdir=out_directory,
            version=version,
            vocab_file=vocab_path + '/vocab.txt',
            mode=mode,

            # optional
            do_lower_case=do_lower_case,
            max_seq_length=512,  # must match max_position_embeddings in the model's config json
        )
print("finished")