# Task 3. Predict department

In [None]:
import sentencepiece as spm
import os
import six
import time
import random
import collections

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

import glob

from tokenization import BertTokenizer
import tokenization as tokenization

## 4. write_instance_to_example_files

In [None]:
class TrainingInstance:
    """Fixed-length model input for one single-sentence classification example.

    Instances are built by ``write_instance_to_example_files`` after padding,
    so every list attribute is expected to be ``max_seq_length`` long.
    """

    def __init__(self, input_ids, input_mask, segment_ids, labels):
        # input_ids:   token ids, right-padded with 0
        # input_mask:  1 for real tokens, 0 for padding positions
        # segment_ids: segment id per position (padded with 0)
        # labels:      per-position labels; the writer puts the class id at
        #              index 0 and -1 everywhere else
        self.input_ids, self.input_mask, self.segment_ids, self.labels = (
            input_ids, input_mask, segment_ids, labels)

In [None]:
def write_instance_to_example_files(instances,
                                    tokenizer,
                                    vocab_words,
                                    max_seq_length,
                                    outfilename):
    """Pad instances to ``max_seq_length`` and pickle them to ``outfilename``.

    Tokens are mapped to ids with ``tokenizer.convert_tokens_to_ids`` and
    right-padded with 0. ``input_mask`` marks real tokens with 1 and padding
    with 0. The label list holds the instance's class id at position 0 and
    -1 everywhere else.

    Args:
        instances: iterable of objects with ``tokens``, ``segment_ids`` and
            ``labels`` attributes (e.g. ``TrainingInstance_tmp``).
        tokenizer: object providing ``convert_tokens_to_ids``.
        vocab_words: unused; kept for interface compatibility.
        max_seq_length: length every output sequence is padded to.
        outfilename: destination of the pickled list of ``TrainingInstance``.
            Nothing is written when ``instances`` is empty.

    Raises:
        ValueError: if an instance has more than ``max_seq_length`` tokens,
            or its segment ids do not match its token count.
    """
    features = []

    for instance in instances:
        # Copy so that padding never mutates a list owned by the tokenizer.
        input_ids = list(tokenizer.convert_tokens_to_ids(instance.tokens))
        segment_ids = list(instance.segment_ids)
        num_tokens = len(input_ids)

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if num_tokens > max_seq_length:
            raise ValueError(
                "instance has %d tokens, more than max_seq_length=%d"
                % (num_tokens, max_seq_length))
        if len(segment_ids) != num_tokens:
            raise ValueError(
                "instance has %d segment ids for %d tokens"
                % (len(segment_ids), num_tokens))

        padding = [0] * (max_seq_length - num_tokens)
        input_mask = [1] * num_tokens + padding
        input_ids += padding
        segment_ids += padding

        # The class id goes in position 0 (the [CLS] token for instances
        # built by create_instances_from_document); every other slot is -1.
        labels = [-1] * max_seq_length
        labels[0] = instance.labels

        features.append(
            TrainingInstance(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                labels=labels,
            )
        )

    if features:
        with open(outfilename, "wb") as output:
            pickle.dump(features, output, pickle.HIGHEST_PROTOCOL)

In [None]:
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Trim two token lists in place until their combined length fits.

    On every step the currently longer list (``tokens_b`` on a tie) loses
    one token. ``rng.random() < 0.5`` removes it from the front; otherwise
    it is removed from the back.

    Args:
        tokens_a, tokens_b: mutable lists, modified in place.
        max_num_tokens: upper bound on ``len(tokens_a) + len(tokens_b)``.
        rng: ``random.Random``-like object supplying ``random()``.

    An assertion fails if a token would have to be removed from an empty
    list (only possible when ``max_num_tokens`` is negative).
    """
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        victim = tokens_b if len(tokens_a) <= len(tokens_b) else tokens_a
        assert len(victim) >= 1
        if rng.random() >= 0.5:
            victim.pop()
        else:
            del victim[0]

In [None]:
def truncate_seq_single(tokens_a, max_num_tokens, rng):
    """Truncate one token list in place to at most ``max_num_tokens`` items.

    Each surplus token is removed from the front or the back with equal
    probability: one ``rng.random()`` draw per removed token, front when the
    draw is < 0.5. The decisions do not depend on the list contents, so the
    draws are tallied first and the list is then cut with two slice
    deletions. This avoids deleting index 0 repeatedly, which costs O(n) per
    deletion and makes long inputs quadratic. The result and the sequence of
    ``rng`` calls are the same as removing tokens one at a time.

    Args:
        tokens_a: mutable list of tokens, modified in place.
        max_num_tokens: maximum number of tokens to keep (an int).
        rng: ``random.Random``-like object supplying ``random()``.

    Raises:
        ValueError: if truncation is needed but ``max_num_tokens`` is
            negative, so the list can never fit. The list is left untouched.
    """
    excess = len(tokens_a) - max_num_tokens
    if excess <= 0:
        return
    if max_num_tokens < 0:
        raise ValueError(
            "max_num_tokens must be non-negative, got %r" % (max_num_tokens,))

    from_front = sum(1 for _ in range(excess) if rng.random() < 0.5)
    from_back = excess - from_front

    # Guard the tail cut: ``del tokens_a[-0:]`` would clear the whole list.
    if from_back:
        del tokens_a[-from_back:]
    del tokens_a[:from_front]

## 2. create_instances_from_document

In [None]:
class TrainingInstance_tmp:
    """Intermediate single-sentence example, before tokens become ids.

    ``create_instances_from_document`` builds these with tokens framed as
    ``[CLS] ... [SEP]``. ``write_instance_to_example_files`` then converts
    them into padded ``TrainingInstance`` objects.
    """

    def __init__(self, tokens, segment_ids, labels):
        # tokens:      wordpiece strings
        # segment_ids: one segment id per token
        # labels:      integer class id of the example
        self.tokens, self.segment_ids, self.labels = tokens, segment_ids, labels

In [None]:
def create_instances_from_document(
    samples, max_seq_length, rng, vocab_words, tokenizer):
    """Turn parsed records into ``[CLS] tokens [SEP]`` training instances.

    Args:
        samples: sequence of records ``[label_text, label_int, tokens]`` (the
            shape ``read_dataset`` returns). ``label_int`` is converted with
            ``int()``, and ``tokens`` is truncated in place.
        max_seq_length: final sequence length; two positions are reserved
            for the [CLS] and [SEP] markers.
        rng: random source forwarded to ``truncate_seq_single``.
        vocab_words, tokenizer: unused; kept for interface compatibility.

    Returns:
        list of ``TrainingInstance_tmp``, in the same order as ``samples``,
        with all segment ids set to 0.
    """
    budget = max_seq_length - 2  # room left after [CLS] and [SEP]
    instances = []

    for record in samples:
        label = int(record[1])
        body = record[2]
        truncate_seq_single(body, budget, rng)

        framed = ["[CLS]", *body, "[SEP]"]
        instances.append(
            TrainingInstance_tmp(
                tokens=framed,
                segment_ids=[0] * len(framed),
                labels=label,
            ))

    return instances

## 1. Create Training instances

In [None]:
def create_training_instances(samples, tokenizer, vocab_words, max_seq_length, rng):
    """Build ``TrainingInstance_tmp`` objects from records and shuffle them.

    Delegates to ``create_instances_from_document``, prints the resulting
    count, then shuffles the list in place with ``rng`` before returning it.
    """
    instances = list(
        create_instances_from_document(
            samples, max_seq_length, rng, vocab_words, tokenizer))

    # "최종" means "final" (count after conversion).
    print("len(instances) 최종: ", len(instances))
    rng.shuffle(instances)

    return instances

    

def convert_to_unicode(text):
    """Return ``text`` as ``str``, decoding ``bytes`` as UTF-8.

    Byte sequences that are not valid UTF-8 are dropped
    (``errors="ignore"``). Python 3 only: the rest of this file already
    requires Python 3.

    Raises:
        ValueError: if ``text`` is neither ``str`` nor ``bytes``.
    """
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))

In [None]:
def read_dataset(targetfile, tokenizer, version, startline, endline, do_lower_case):
    """Parse and tokenize lines ``[startline:endline]`` of a tab-separated file.

    Each line is expected to look like ``label_text<TAB>label_int<TAB>text``.
    Lines with fewer than three tab-separated fields are skipped; fields
    after the third are ignored. Progress is printed every 100 lines.

    Args:
        targetfile: path of the tab-separated input file.
        tokenizer: object with a ``tokenize(str) -> list`` method.
        version: unused; kept for interface compatibility.
        startline, endline: slice bounds applied to the file's lines
            (ordinary Python slice semantics).
        do_lower_case: lower-case the text before tokenizing when truthy.

    Returns:
        list of ``[label_text, label_int, tokens]`` records. Both labels are
        returned as the raw strings read from the file.
    """
    # ``with`` closes the handle deterministically once the lines are read.
    with open(targetfile, "r") as f:
        lines = f.readlines()[startline:endline]

    stacks = []
    total = len(lines)
    for idx, raw in enumerate(lines):
        if idx % 100 == 0:
            print(idx, "/", total)

        fields = raw.strip("\n").split("\t")
        if len(fields) < 3:
            continue
        label_text, label_int, text = fields[0], fields[1], fields[2]

        if do_lower_case:
            text = text.lower()

        stacks.append([label_text, label_int, tokenizer.tokenize(text)])
    return stacks

# main()

In [None]:
def main(
    # required
    input_file,
    vocab_file,
    version,
    outdir,

    # optional
    do_lower_case=True,
    max_seq_length=512,
    random_seed=12345,
    chunk_size=1000,
    ):
    """Tokenize ``input_file`` and write padded features in pickled chunks.

    The input is processed ``chunk_size`` lines at a time. Each chunk is:

    - parsed and tokenized by ``read_dataset``;
    - converted and shuffled by ``create_training_instances``;
    - pickled by ``write_instance_to_example_files``.

    The output goes to ``<outdir>/<start line zero-padded to 8 digits>.cache``.

    Args:
        input_file: tab-separated data file (see ``read_dataset``).
        vocab_file: vocabulary file passed to ``BertTokenizer``.
        version: model/version name, forwarded to ``read_dataset``.
        outdir: existing directory that receives the ``.cache`` files.
        do_lower_case: forwarded to the tokenizer and to ``read_dataset``.
        max_seq_length: padded sequence length.
        random_seed: seed for the truncation/shuffling RNG, so regenerated
            caches are reproducible. Pass ``None`` to draw a random seed in
            [1, 1000] instead; the chosen seed is printed either way.
        chunk_size: number of input lines per cache file.
    """
    print("vocab_file: ", vocab_file)
    tokenizer = BertTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case, max_len=max_seq_length)
    vocab_words = list(tokenizer.vocab.keys())
    print("len(vocab_words): ", len(vocab_words))

    # Honour the caller's seed; fall back to a random one only on request.
    randseed = random.randint(1, 1000) if random_seed is None else random_seed
    print("randseed: ", randseed)
    rng = random.Random(randseed)

    # Only the line count is needed here; read_dataset re-reads each chunk.
    with open(input_file, "r") as f:
        num_lines = sum(1 for _ in f)

    print("generating...")
    print("len(lines): ", num_lines)
    for start in range(0, num_lines, chunk_size):
        end = min(start + chunk_size, num_lines)
        print(start, "~", end, "/", num_lines)
        samples = read_dataset(input_file, tokenizer, version,
                               startline=start, endline=end,
                               do_lower_case=do_lower_case)

        instances = create_training_instances(samples, tokenizer, vocab_words, max_seq_length, rng)

        write_instance_to_example_files(
            instances=instances,
            tokenizer=tokenizer,
            vocab_words=vocab_words,
            max_seq_length=max_seq_length,
            outfilename=os.path.join(outdir, "%08d.cache" % start))

# Generate data

In [None]:
import glob
import os


def _ensure_dir(path):
    """Create ``path`` (with parents) unless something already exists there."""
    if not os.path.exists(path):
        os.makedirs(path)


_ensure_dir("./cache")

print("Generating Features...")

datapath = ["./data/train.txt", "./data/test.txt"]

versions = [
    "bertbase_cased",
    "biobert",
    "mbert_cased",
]

vocab_paths = [
    "../otherberts/bertbase_cased",
    "../otherberts/bioBERT/biobert_v1.1_pubmed",
    "../otherberts/mbert_cased",
]

do_lower_cases = [False, False, False]

# The three per-model lists are consumed in lockstep below.
assert len(versions)==len(vocab_paths)
assert len(versions)==len(do_lower_cases)

# One cache tree per model: ./cache/<version>/<dataset name>/*.cache
for version, vocab_path, lower in zip(versions, vocab_paths, do_lower_cases):
    print("versions: ", version)
    _ensure_dir("./cache/" + str(version) + "/")

    for path in datapath:
        print("datapath[s]: ", path)
        # "./data/train.txt" -> "train"
        datapathname = path.rsplit("/", 1)[-1].partition(".")[0]
        out_directory = "./cache/" + str(version) + "/" + str(datapathname)
        _ensure_dir(out_directory)

        main(
            input_file=path,
            outdir=out_directory,
            version=version,
            vocab_file=vocab_path + '/vocab.txt',
            do_lower_case=lower,
            max_seq_length=512,
        )
print("finished")