# Make cache

In [None]:
# pip install sentencepiece
import sentencepiece as spm
import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

import pandas as pd
import sentencepiece as spm

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance_ext(object):
    """Id-encoded, padded features for one (query, candidate) sentence pair.

    Attributes ending in ``_0`` belong to the first sequence, those ending in
    ``_1`` to the second; ``label_ids`` holds the pair's label list.
    """

    def __init__(self, input_ids_0, input_mask_0, segment_ids_0,
                 input_ids_1, input_mask_1, segment_ids_1, label_ids):
        # First sequence: token ids, input mask, segment ids.
        self.input_ids_0, self.input_mask_0, self.segment_ids_0 = (
            input_ids_0, input_mask_0, segment_ids_0)
        # Second sequence: token ids, input mask, segment ids.
        self.input_ids_1, self.input_mask_1, self.segment_ids_1 = (
            input_ids_1, input_mask_1, segment_ids_1)
        self.label_ids = label_ids
        

In [None]:
def write_instance_to_example_files(instances,
                                    tokenizer,
                                    vocab_words,
                                    max_seq_length,
                                    max_predictions_per_seq,
                                    outfilename):
    """Convert token-level instances to padded id features and pickle them.

    For every instance, the query and candidate tokens are mapped to ids via
    ``vocab_words.w_to_i``. Each side is right-padded to ``max_seq_length``
    with the ``[PAD]`` id, an input-mask value of 0, and a segment id that is
    the opposite (0 <-> 1) of that side's last real segment id. ``label_ids``
    is right-padded with 0. The resulting list of ``TrainingInstance_ext``
    objects is pickled to *outfilename* with ``pickle.HIGHEST_PROTOCOL``;
    nothing is written when *instances* is empty.

    ``tokenizer`` and ``max_predictions_per_seq`` are accepted but unused.

    Raises:
        KeyError: if ``[PAD]`` or any token is missing from ``w_to_i``.
        AssertionError: if a sequence exceeds ``max_seq_length`` (these checks
            are skipped under ``python -O``).
    """
    word_to_id = vocab_words.w_to_i
    pad_id = word_to_id["[PAD]"]
    features = []

    for inst in instances:
        # kobert: token strings -> vocabulary ids
        query_ids = [word_to_id[tok] for tok in inst.tokens_query]
        doc_ids = [word_to_id[tok] for tok in inst.tokens_doc]

        query_mask = [1] * len(query_ids)
        doc_mask = [1] * len(doc_ids)

        query_segs = list(inst.segment_ids_query)
        doc_segs = list(inst.segment_ids_doc)
        labels = list(inst.label_ids)

        # Padding positions get the segment id opposite to the last real one.
        query_pad_seg = 1 if query_segs[-1] == 0 else 0
        doc_pad_seg = 1 if doc_segs[-1] == 0 else 0

        assert len(query_ids) <= max_seq_length
        assert len(doc_ids) <= max_seq_length

        n_pad = max_seq_length - len(query_ids)
        query_ids.extend([pad_id] * n_pad)
        query_mask.extend([0] * n_pad)
        query_segs.extend([query_pad_seg] * n_pad)
        assert len(query_ids) == max_seq_length
        assert len(query_mask) == max_seq_length
        assert len(query_segs) == max_seq_length

        n_pad = max_seq_length - len(doc_ids)
        doc_ids.extend([pad_id] * n_pad)
        doc_mask.extend([0] * n_pad)
        doc_segs.extend([doc_pad_seg] * n_pad)
        assert len(doc_ids) == max_seq_length
        assert len(doc_mask) == max_seq_length
        assert len(doc_segs) == max_seq_length

        # A negative count yields [], so an over-long label list is left as-is.
        labels.extend([0] * (max_seq_length - len(labels)))
        assert len(labels) == max_seq_length

        features.append(TrainingInstance_ext(
            input_ids_0=query_ids,
            input_mask_0=query_mask,
            segment_ids_0=query_segs,

            input_ids_1=doc_ids,
            input_mask_1=doc_mask,
            segment_ids_1=doc_segs,

            label_ids=labels,
        ))

    if features:
        with open(outfilename, 'wb') as handle:
            pickle.dump(features, handle, pickle.HIGHEST_PROTOCOL)

## 2. create_instances_from_document

In [None]:
class TrainingInstance_ext_tmp(object):
    """Token-level (query, candidate) pair before conversion to vocabulary ids.

    Stores the token strings and segment ids of both sequences together with
    the pair's label list.
    """

    def __init__(self, tokens_query, segment_ids_query,
                 tokens_doc, segment_ids_doc, label_ids):
        # Query sequence.
        self.tokens_query, self.segment_ids_query = tokens_query, segment_ids_query
        # Candidate (document) sequence.
        self.tokens_doc, self.segment_ids_doc = tokens_doc, segment_ids_doc
        self.label_ids = label_ids
        

In [None]:
def truncate_seq(tokens_a, max_num_tokens):
    """Return the first ``max_num_tokens`` items of a single token sequence.

    The result is a new slice; ``tokens_a`` itself is never modified, so the
    caller may mutate the returned list freely. Sequences already within the
    limit are returned whole (as a copy).
    """
    return tokens_a[:max_num_tokens]

In [None]:
def tokenize_contents(input_seq, tokenizer, do_lower_case, vocab_words):
    """Tokenize *input_seq* into vocabulary token strings.

    Args:
        input_seq: raw text to tokenize.
        tokenizer: object exposing ``Encode(text)`` that returns a sequence of
            token ids (the sentencepiece processor in this file).
        do_lower_case: if truthy, lowercase the text first and then restore
            the ``[MASK]`` and ``[SEP]`` markers to their upper-case form.
        vocab_words: object whose ``i_to_w`` dict maps token id -> token string.

    Returns:
        list of token strings, one per id produced by ``tokenizer.Encode``.
    """
    if do_lower_case:
        input_seq = input_seq.lower()
        # Lowercasing also hits the special markers; put them back.
        input_seq = input_seq.replace("[mask]", "[MASK]")
        input_seq = input_seq.replace("[sep]", "[SEP]")
    token_ids = tokenizer.Encode(input_seq)  # kobert
    return [vocab_words.i_to_w[token_id] for token_id in token_ids]


def read_documents(targetfile, tokenizer, do_lower_case, vocab_words):
    """Parse a question file into tokenized sentence groups and their labels.

    Format, as parsed here:
      * blank lines are ignored;
      * a ``[START_QUESTION]`` line opens a group and ``[END_QUESTION]``
        closes it (each closed group must contain exactly 5 entries);
      * every other line is tab-separated
        ``filename<TAB>doc_order<TAB>date<TAB>content``. Only ``doc_order``
        and ``content`` are used: ``content`` is tokenized with
        ``tokenize_contents`` and the label is 1 when ``doc_order == 4``,
        otherwise 0.

    Returns:
        ``(questions, labels)`` where ``questions[q]`` is a list of token lists
        and ``labels[q]`` the matching list of int labels.

    Raises:
        ValueError: if a content line appears before any ``[START_QUESTION]``.
        AssertionError: if a group does not hold exactly 5 entries.
    """
    # Read as UTF-8 explicitly rather than relying on the locale's default
    # encoding, and close the file deterministically.
    with open(targetfile, "r", encoding="utf-8") as f:
        lines = f.readlines()

    questions = []
    labels = []
    quest = None
    label_q = None
    for raw_line in lines:
        line = raw_line.strip("\n")
        if line == "":
            continue

        # info line
        if line == "[START_QUESTION]":
            quest = []
            label_q = []
        elif line == "[END_QUESTION]":
            assert len(quest) == 5
            assert len(quest) == len(label_q)

            questions.append(quest)
            labels.append(label_q)
        else:
            if quest is None:
                raise ValueError(
                    "%s: content line before [START_QUESTION]: %r" % (targetfile, line))
            fields = line.split("\t")
            doc_order = int(fields[1])
            content = fields[3]

            content_tkzd = tokenize_contents(content, tokenizer, do_lower_case, vocab_words)

            quest.append(content_tkzd)
            label_q.append(1 if doc_order == 4 else 0)

    return questions, labels

In [None]:
def create_instances_from_document(input_file, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, rng, do_lower_case, vocab_words, mecab_sp_tokenizer, mode):
    """Build token-level (query, candidate) instances from one question file.

    For every question group returned by ``read_documents``, entry 0 is the
    query and every later entry is a candidate. One ``TrainingInstance_ext_tmp``
    is produced per candidate. Each instance pairs ``[CLS] query [SEP]``
    (segment id 0) with ``[CLS] candidate [SEP]`` (segment id 1), and carries
    the candidate's label wrapped in a one-element list.

    Each sequence is truncated to ``max_seq_length - 2`` tokens before the two
    markers are added, so no sequence exceeds ``max_seq_length``.

    ``short_seq_prob``, ``masked_lm_prob``, ``max_predictions_per_seq``,
    ``rng`` and ``mode`` are accepted for interface compatibility but unused.

    Raises:
        ValueError: if ``max_seq_length`` leaves no room for the two markers.
    """
    if max_seq_length < 2:
        raise ValueError("max_seq_length must be at least 2, got %r" % (max_seq_length,))
    # Reserve room for [CLS] and [SEP]. Derived from max_seq_length instead of
    # a hard-coded 512 so other lengths pass the downstream padding checks.
    truncate_len = max_seq_length - 2

    def _wrap(tokens):
        """Truncate *tokens* and surround them with [CLS] ... [SEP]."""
        # Append [SEP] instead of overwriting the last element, which dropped
        # a real token from every sequence that did not need truncation.
        return ["[CLS]"] + truncate_seq(tokens, truncate_len) + ["[SEP]"]

    ########### tokenize all documents ###########
    questions, labels = read_documents(input_file, mecab_sp_tokenizer, do_lower_case, vocab_words)

    assert len(questions) == len(labels)

    instances = []
    for quest, label_q in zip(questions, labels):
        # query
        tokens_query = _wrap(quest[0])
        segment_ids_query = [0] * len(tokens_query)

        # candidates
        for d in range(1, len(quest)):
            tokens_doc = _wrap(quest[d])
            instances.append(TrainingInstance_ext_tmp(
                tokens_query=tokens_query,
                segment_ids_query=segment_ids_query,

                tokens_doc=tokens_doc,
                segment_ids_doc=[1] * len(tokens_doc),

                label_ids=[label_q[d]],
            ))

    return instances


## 1. Create Training instances

In [None]:
def create_training_instances(input_file, tokenizer, vocab_words, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng, do_lower_case, mode):
    """Return a fresh list of the instances built for *input_file*.

    Thin wrapper around ``create_instances_from_document``; the file is
    processed exactly once, and ``dupe_factor`` is accepted but unused.
    """
    generated = create_instances_from_document(
        input_file, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, rng, do_lower_case,
        vocab_words, tokenizer, mode)
    return list(generated)

def convert_to_unicode(text):
    """Return *text* as a unicode string, decoding UTF-8 bytes when needed.

    Undecodable bytes are dropped (``errors="ignore"``). Any other input type
    raises ``ValueError``.
    """
    # Pick the (text type, bytes type) pair for the running interpreter.
    if six.PY3:
        text_type, bytes_type = str, bytes
    elif six.PY2:
        text_type, bytes_type = unicode, str  # noqa: F821 -- Python 2 only
    else:
        raise ValueError("Not running on Python2 or Python 3?")

    if isinstance(text, text_type):
        return text
    if isinstance(text, bytes_type):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

# main()

In [None]:
def main(
    # required
    input_files,
    vocab_file,
    outdir,
    mode,
    spmodel,

    # optional
    do_lower_case=True,
    max_seq_length=512,
    max_predictions_per_seq = 20,
    random_seed=12345,
    dupe_factor = 1,
    masked_lm_prob = 0.15,
    short_seq_prob = 0.1,
    ):
    """Write one ``.cache`` feature file into *outdir* for each input file.

    Each path in *input_files* is tokenized with the sentencepiece model
    *spmodel*. Tokens are mapped to ids through *vocab_file*, which holds one
    token per line, with the 0-based line number as the id. Results are padded
    to *max_seq_length* and pickled to
    ``<outdir>/<input file name without its last extension>.cache``.

    ``dupe_factor``, ``masked_lm_prob``, ``short_seq_prob``,
    ``max_predictions_per_seq`` and ``mode`` are only forwarded to helpers.
    """
    # Honour random_seed so runs are reproducible; it used to be ignored in
    # favour of a fresh seed drawn from the global RNG.
    print("randseed: ", random_seed)
    rng = random.Random(random_seed)

    # vocab_words
    class Vocab_words(object):
        """Bidirectional token <-> id lookup built from a vocab file."""

        def __init__(self, vocab_file):
            self.i_to_w = {}
            self.w_to_i = {}
            self.getvocab(vocab_file)

        def getvocab(self, vocab_file):
            # One token per line; the line index is the token id. Read as
            # UTF-8 and close the handle (it previously leaked).
            with open(vocab_file, 'r', encoding='utf-8') as f:
                for idx, line in enumerate(f):
                    term = convert_to_unicode(line.strip("\n"))
                    self.i_to_w[idx] = term
                    self.w_to_i[term] = idx

    vocab_words = Vocab_words(vocab_file)

    # sptokenizer
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load(spmodel)

    print("len(input_files): ", len(input_files))

    for i, input_file in enumerate(input_files):
        if i % 100 == 0:
            print(f"{i}/{len(input_files)}")
        instances = create_training_instances(
                        input_file, tokenizer, vocab_words, max_seq_length,
                        dupe_factor, short_seq_prob, masked_lm_prob,
                        max_predictions_per_seq, rng, do_lower_case, mode)
        # Strip only the final extension so names like "a.1.txt" and
        # "a.2.txt" do not collide on the same "a.cache" output file.
        stem = os.path.splitext(os.path.basename(input_file))[0]

        write_instance_to_example_files(instances=instances,
                                    tokenizer=tokenizer,
                                    vocab_words=vocab_words,
                                    max_seq_length=max_seq_length,
                                    max_predictions_per_seq=max_predictions_per_seq,
                                    outfilename=os.path.join(outdir, stem + ".cache"))

# Generate train and test feature caches

In [None]:
import glob
import os

In [None]:
print("Generating Features...")

data_dirs = [
    "./data/05_samples/train",
    "./data/05_samples/test",
]

output_paths = [
    "kobert"
]

vocab_paths = [
    "../otherberts/KoBERT/models",
]

do_lower_cases = [False]

# modes[k] is the split name for data_dirs[k] and names its cache subdirectory.
modes = ["train", "test"]

# These lists are consumed pairwise; fail loudly on a length mismatch instead
# of crashing midway or silently skipping a configuration.
if len(data_dirs) != len(modes):
    raise ValueError("data_dirs and modes must have the same length")
if not len(output_paths) == len(vocab_paths) == len(do_lower_cases):
    raise ValueError("output_paths, vocab_paths and do_lower_cases must have the same length")

for data_dir, mode in zip(data_dirs, modes):
    for output_path, vocab_path, do_lower_case in zip(output_paths, vocab_paths, do_lower_cases):
        print("data_dirs[d]: ", data_dir)

        out_directory = "./cache/" + str(output_path) + "/" + str(mode)
        print("out_directory: ", out_directory)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(out_directory, exist_ok=True)

        input_files = sorted(glob.glob(data_dir + "/*.txt"))

        main(
            input_files = input_files,
            outdir = out_directory,
            mode=mode,

            #sentence piece
            #https://skt-lsl-nlp-model.s3.amazonaws.com/KoBERT/tokenizers/kobert_news_wiki_ko_cased-1087f8699e.spiece
            vocab_file = vocab_path + '/vocab.txt',
            spmodel = vocab_path + '/spiece.model',

            # optional
            do_lower_case = do_lower_case,
            max_seq_length = 512,
        )
print("Done")