# Make cache

In [None]:
# pip install sentencepiece
import sentencepiece as spm
import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

import pandas as pd

from tokenization import BertTokenizer
import tokenization as tokenization

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance_ext(object):
    """Id-level feature record for one example, as stored in the cache file.

    All five attributes are kept exactly as passed in; no copying or
    validation happens here.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids, mask_location):
        # Vocabulary ids and the matching attention mask.
        self.input_ids, self.input_mask = input_ids, input_mask
        # Per-position segment ids and label ids.
        self.segment_ids, self.label_ids = segment_ids, label_ids
        # Per-position flags marking the masked token.
        self.mask_location = mask_location


In [None]:
def write_instance_to_example_files(instances,
                                    tokenizer,
                                    vocab_words,
                                    max_seq_length,
                                    max_predictions_per_seq,
                                    outfilename):
    """Convert token-level instances to padded id features and pickle them.

    Args:
        instances: iterable of objects exposing ``tokens``, ``segment_ids``,
            ``label_ids`` and ``mask_location``.
        tokenizer: not used by this function; kept so existing callers work.
        vocab_words: object whose ``w_to_i`` mapping turns a token string into
            its integer id. It must contain ``"[PAD]"``.
        max_seq_length: every per-position list is right-padded to this length.
        max_predictions_per_seq: not used by this function; kept so existing
            callers work.
        outfilename: path of the pickle file to write. Nothing is written when
            ``instances`` produces no features.

    Raises:
        KeyError: ``"[PAD]"`` or an instance token is missing from
            ``vocab_words.w_to_i``.
        IndexError: an instance has empty ``segment_ids``.
        AssertionError: an instance is longer than ``max_seq_length``, or one
            of its per-position lists does not end up with exactly
            ``max_seq_length`` entries.
    """
    pad_id = vocab_words.w_to_i["[PAD]"]
    features = []

    for instance in instances:
        input_ids = [vocab_words.w_to_i[token] for token in instance.tokens]
        input_mask = [1] * len(input_ids)

        # Copy so that padding never mutates the caller's instance.
        segment_ids = list(instance.segment_ids)
        label_ids = list(instance.label_ids)
        mask_location = list(instance.mask_location)

        # Padding positions receive the segment id opposite to the one of the
        # last real position.
        pad_segment_id = 1 if segment_ids[-1] == 0 else 0

        assert len(input_ids) <= max_seq_length

        # Right-pad every list by the same amount: pad token id, attention
        # mask 0, label -1 and mask flag 0.
        pad_len = max_seq_length - len(input_ids)
        input_ids.extend([pad_id] * pad_len)
        input_mask.extend([0] * pad_len)
        segment_ids.extend([pad_segment_id] * pad_len)
        label_ids.extend([-1] * pad_len)
        mask_location.extend([0] * pad_len)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        # label_ids was previously the only list left unchecked, so labels
        # misaligned with the tokens could be written out silently.
        assert len(label_ids) == max_seq_length
        assert len(mask_location) == max_seq_length

        features.append(
            TrainingInstance_ext(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_ids=label_ids,
                mask_location=mask_location,
            )
        )

    if features:
        with open(outfilename, 'wb') as output:
            pickle.dump(features, output, pickle.HIGHEST_PROTOCOL)

## 2. create_instances_from_document

In [None]:
class TrainingInstance_ext_tmp(object):
    """Token-level record for one example, before conversion to ids.

    The four attributes are stored exactly as given.
    """

    def __init__(self, tokens, segment_ids, label_ids, mask_location):
        # Token strings and their per-position segment ids.
        self.tokens, self.segment_ids = tokens, segment_ids
        # Per-position label ids and masked-token flags.
        self.label_ids, self.mask_location = label_ids, mask_location


In [None]:
def read_samples(targetfile, tokenizer, do_lower_case, label_map, vocab_words):
    """Read a tab-separated sample file and tokenize each line around its mask.

    Every non-empty line must have at least five tab-separated fields. The
    third field is the key looked up in ``label_map`` and the fifth is the
    text. The text is split on the literal ``"[MASK]"``; each piece is
    optionally lower-cased, encoded with ``tokenizer.Encode`` and mapped back
    to token strings through ``vocab_words.i_to_w``. The pieces are then
    rejoined with a single ``"[MASK]"`` token between consecutive pieces.

    Args:
        targetfile: path of the sample file, read as UTF-8.
        tokenizer: object with an ``Encode(text)`` method returning ids.
        do_lower_case: when true, lower-case each text piece before encoding.
        label_map: mapping from the third field of a line to its label.
        vocab_words: object whose ``i_to_w`` maps an id to its token string.

    Returns:
        ``(samples, labels)``: ``samples`` is a list of token lists and
        ``labels`` the matching list of labels, one entry per non-empty line.

    Raises:
        IndexError: a non-empty line has fewer than five fields.
        KeyError: the third field is missing from ``label_map`` or an id is
            missing from ``vocab_words.i_to_w``.
    """
    samples = []
    labels = []

    # Explicit encoding keeps Korean text from depending on the platform
    # default; the context manager closes the file even if parsing fails.
    with open(targetfile, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip("\n")
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue

            fields = line.split("\t")
            entinfo = fields[2]
            content = fields[4]

            tokens = []
            for piece_idx, piece in enumerate(content.split("[MASK]")):
                if piece_idx:
                    # Restore the mask that separated this piece from the
                    # previous one.
                    tokens.append("[MASK]")
                if do_lower_case:
                    piece = piece.lower()
                tokens.extend(
                    vocab_words.i_to_w[token_id]
                    for token_id in tokenizer.Encode(piece)  # kobert
                )

            samples.append(tokens)
            labels.append(label_map[entinfo])

    return samples, labels

def read_labels(label_path="./data/labels.txt"):
    """Build a label-name -> index mapping from a one-label-per-line file.

    Each line is stripped of surrounding whitespace and every ``","`` in it is
    replaced by ``"/"``; the resulting string is mapped to the 0-based line
    number. Blank lines still consume an index, and if a name occurs more than
    once the last line number wins.

    Args:
        label_path: path of the label file, read as UTF-8.

    Returns:
        dict mapping the normalised label string to its line index.
    """
    # The context manager guarantees the handle is closed; the explicit
    # encoding avoids depending on the platform default.
    with open(label_path, "r", encoding="utf-8") as label_file:
        return {
            line.strip().replace(",", "/"): index
            for index, line in enumerate(label_file)
        }


In [None]:
def truncate_seq_sides(tokens_a, tokens_b, max_num_tokens, rng):
    """Trim two token lists in place until they fit in `max_num_tokens`.

    While the combined length is too large, one token is removed per step:
    from the front of `tokens_a` when it is strictly longer than `tokens_b`,
    otherwise from the end of `tokens_b`. `rng` is accepted but not used.
    Returns the same two (mutated) list objects.
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        a_is_longer = len(tokens_a) > len(tokens_b)
        if a_is_longer:
            # Drop the token farthest to the left.
            del tokens_a[0]
        else:
            # Drop the token farthest to the right.
            tokens_b.pop()

    return tokens_a, tokens_b

def create_instances_from_document(input_file, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, rng, do_lower_case, vocab_words, tokenizer, mode):
    """Turn every sample of `input_file` into one single-mask instance.

    Labels come from "./data/labels.txt" and samples from `read_samples`.
    For each sample the tokens left and right of the first mask token are
    trimmed with `truncate_seq_sides` so that, together with "[CLS]",
    "[MASK]" and "[SEP]", they fit in `max_seq_length`. The label is stored at
    the mask position and -1 everywhere else.

    `short_seq_prob`, `masked_lm_prob`, `max_predictions_per_seq` and `mode`
    are not used in this function; `rng` is only forwarded to
    `truncate_seq_sides`.
    """
    mask_tokens = ("[MASK]", "[mask]")

    ########### tokenize all documents ###########
    label_map = read_labels(label_path="./data/labels.txt")
    samples, labels = read_samples(input_file, tokenizer, do_lower_case, label_map, vocab_words)

    instances = []
    for sample_tokens, label in zip(samples, labels):
        # Position of the first mask token.
        # NOTE(review): when a sample has no mask token this falls back to 0,
        # so the first token is dropped and a mask is inserted in its place —
        # confirm that this is intended.
        mask_idx = next(
            (pos for pos, tok in enumerate(sample_tokens) if tok in mask_tokens),
            0,
        )

        # Reserve three positions for [CLS], [MASK] and [SEP].
        left, right = truncate_seq_sides(
            sample_tokens[:mask_idx],
            sample_tokens[mask_idx+1:],
            max_seq_length-3,
            rng,
        )

        tokens = ["[CLS]"] + left + ["[MASK]"] + right + ["[SEP]"]
        length = len(tokens)
        segment_ids = [0] * length
        mask_location = [0] * length
        label_ids = [-1] * length

        if length > max_seq_length:
            print("len(tokens): ", len(tokens))
            print("tokens: ", tokens)
            print("label_ids: ", label_ids)

        # Flag the first mask token and attach the sample's label to it.
        for pos, tok in enumerate(tokens):
            if tok in mask_tokens:
                mask_location[pos] = 1
                label_ids[pos] = label
                break

        assert sum(mask_location)==1
        assert len(tokens) <= max_seq_length
        assert len(tokens)==len(segment_ids)
        assert len(segment_ids)==len(label_ids)

        instances.append(
            TrainingInstance_ext_tmp(
                tokens = tokens,
                segment_ids = segment_ids,
                label_ids = label_ids,
                mask_location = mask_location,
            )
        )

    return instances

## 1. Create Training instances

In [None]:
def create_training_instances(input_file, tokenizer, vocab_words, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng, do_lower_case, mode):
    """Return the list of instances built from `input_file`.

    Thin wrapper around `create_instances_from_document`, which is called
    exactly once; `dupe_factor` is accepted but not used.
    """
    document_instances = create_instances_from_document(
        input_file, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, rng, do_lower_case,
        vocab_words, tokenizer, mode)

    # Always hand back a fresh list, independent of what the helper returned.
    return list(document_instances)

def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Text strings are returned unchanged; byte strings are decoded as UTF-8
    with undecodable bytes dropped.

    Raises:
        ValueError: `text` is neither a text string nor a byte string.
    """
    # type(u"") is `str` on Python 3 and `unicode` on Python 2, so the checks
    # below need neither `six` nor the Python-2-only name `unicode`.
    text_type = type(u"")
    if isinstance(text, text_type):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

# main()

In [None]:
def main(
    # required
    input_files,
    vocab_file,
    outdir,
    mode,
    spmodel,

    # optional
    do_lower_case=True,
    max_seq_length=512,
    max_predictions_per_seq = 20,
    random_seed=12345,
    dupe_factor = 1,
    masked_lm_prob = 0.15,
    short_seq_prob = 0.1,
    ):
    """Build one pickled feature cache per input file.

    Args:
        input_files: list of sample-file paths to process.
        vocab_file: path of the vocabulary file (one term per line, UTF-8);
            the 0-based line number is the term's id.
        outdir: directory receiving the ``*.cache`` files; it must exist.
        mode: forwarded unchanged to `create_training_instances`.
        spmodel: path of the SentencePiece model to load.
        do_lower_case: forwarded to `create_training_instances`.
        max_seq_length: length every feature list is padded to.
        max_predictions_per_seq: printed and forwarded to the helpers.
        random_seed: seed of the `random.Random` handed to the helpers.
        dupe_factor: forwarded to `create_training_instances`.
        masked_lm_prob: forwarded to `create_training_instances`.
        short_seq_prob: forwarded to `create_training_instances`.

    The output name of each file is its base name up to the FIRST dot plus
    ``".cache"``.
    """

    # vocab_words
    class Vocab_words(object):
        """Two-way lookup between vocabulary terms and their line-number ids."""

        def __init__(self, vocab_file):
            self.i_to_w = {}
            self.w_to_i = {}
            self.getvocab(vocab_file)

        def getvocab(self, vocab_file):
            # The context manager closes the handle (it used to leak); the
            # explicit encoding matches convert_to_unicode's utf-8 assumption.
            with open(vocab_file, 'r', encoding="utf-8") as f:
                for index, line in enumerate(f):
                    term = convert_to_unicode(line.strip("\n"))
                    self.i_to_w[index] = term
                    # If a term is repeated, the last line number wins.
                    self.w_to_i[term] = index

    vocab_words = Vocab_words(vocab_file)

    # sptokenizer
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(spmodel)

    print("max_predictions_per_seq: ", max_predictions_per_seq)

    # Honour the random_seed argument (it used to be ignored in favour of a
    # fresh random.randint value) so that runs are reproducible.
    randseed = random_seed
    print("randseed: ", randseed)
    rng = random.Random(randseed)

    total_files = len(input_files)
    print("len(input_files): ", total_files)

    for i, input_file in enumerate(input_files):
        if i % 100 == 0:
            # Progress report every 100 files.
            print(str(i)+"/"+str(total_files))
        instances = create_training_instances(
                        input_file, tokenizer, vocab_words, max_seq_length,
                        dupe_factor, short_seq_prob, masked_lm_prob,
                        max_predictions_per_seq, rng, do_lower_case, mode)

        # Everything after the first dot of the base name is dropped.
        filename = os.path.basename(input_file).split(".")[0]+".cache"

        write_instance_to_example_files(instances=instances,
                                    tokenizer=tokenizer,
                                    vocab_words=vocab_words,
                                    max_seq_length=max_seq_length,
                                    max_predictions_per_seq=max_predictions_per_seq,
                                    outfilename=os.path.join(outdir, filename))


# Generate Train data

In [None]:
import glob
import os

In [None]:
print("Generating Features...")

# Sample-file stems; ".txt" is appended when the files are looked up.
data_dirs = [
    "./data/05_samples/train",
    "./data/05_samples/test",
]

# Parallel lists: entry i of each list describes one tokenizer configuration.
output_paths = ["kobert"]
vocab_paths = ["../otherberts/KoBERT/models"]
lowercase = [False]

for d, data_dir in enumerate(data_dirs):
    # The last path component ("train" / "test") names the split.
    mode = data_dir.split("/")[-1]

    for i, output_path in enumerate(output_paths):
        print("data_dirs[d]: ", data_dir)

        out_directory = "/".join(["./cache", str(output_path), str(mode)])
        print("out_directory: ", out_directory)
        if not os.path.exists(out_directory):
            os.makedirs(out_directory)

        vocab_path = vocab_paths[i]

        # The pattern contains no wildcard, so this yields the single file
        # "<data_dir>.txt" when it exists and an empty list otherwise.
        input_files = sorted(glob.glob(data_dir + ".txt"))

        main(
            input_files = input_files,
            outdir = out_directory,
            mode = mode,

            #sentence piece
            #https://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/tokenizers/kobert_news_wiki_ko_cased-1087f8699e.spiece
            vocab_file = vocab_path + '/vocab.txt',
            spmodel = vocab_path + '/spiece.model',

            # optional
            do_lower_case = lowercase[i],
            max_seq_length = 512,
        )
print("Done")