# Task 4. Reading Comprehension (find Assessment section) eval

In [None]:
from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import logging
import os
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F

# multiprocessing
# import torch.multiprocessing as mp

# for horovod distributed training
# import torch.utils.data.distributed

from transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                          BertForTokenClassification, BertTokenizer,
                          get_linear_schedule_with_warmup, 
                          BertPreTrainedModel, BertModel) # cls classifier 를 위해 로드

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
# from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from seqeval.metrics import classification_report
from seqeval.metrics import sequence_labeling # 따로 결과를 뽑아내고 싶음

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)
import codecs

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle
    
import glob
import time

import easydict

import six
import tqdm
from tqdm import tqdm, trange
import collections

from seqeval.metrics import classification_report

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)


import re


# pip install pytest-shutil
import shutil

In [None]:
class TrainingInstance_ext(object):
    """Plain container for the per-example fields stored in the cache files.

    The constructor keeps its six arguments unchanged under attributes of the
    same names; no validation or conversion is performed.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids, doc_ids, sent_ids):
        # Assigned in signature order so ``vars(instance)`` lists them that way.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_ids = segment_ids, label_ids
        self.doc_ids, self.sent_ids = doc_ids, sent_ids

## Position classifier

In [None]:
class PositionClassifier(BertPreTrainedModel):
    """BERT encoder with two token-level heads scoring span start and span end."""

    def __init__(self, config):
        super().__init__(config)
        # Only token-level hidden states are consumed, so no pooling layer is built.
        self.bert = BertModel(config, add_pooling_layer=False)
        # Each head emits two values per token; forward() keeps the one at index 1.
        self.start_outputs = nn.Linear(config.hidden_size, 2)
        self.end_outputs = nn.Linear(config.hidden_size, 2)

        self.init_weights()

    @staticmethod
    def _span_loss(token_scores, targets):
        """Return the target-weighted negative log-likelihood for one head.

        Both tensors are flattened to one dimension first, so the log-softmax
        is normalised over every position of every sequence in the batch at
        once. The epsilon keeps the division finite when ``targets`` sums to 0.
        """
        flat_scores = token_scores.view(token_scores.size(0) * token_scores.size(1))
        flat_targets = targets.view(targets.size(0) * targets.size(1))
        log_probs = nn.functional.log_softmax(flat_scores, dim=-1)
        weighted_nll = (-1) * torch.sum(log_probs * flat_targets)
        target_mass = torch.sum(flat_targets) + 1e-5
        return weighted_nll / target_mass

    @staticmethod
    def _to_probs(token_scores):
        """Return detached CPU probabilities, normalised over the last dimension."""
        cpu_scores = token_scores.detach().cpu()
        return torch.exp(nn.functional.log_softmax(cpu_scores, dim=-1))

    def forward(self, input_ids, label_ids, start_pos, end_pos, doc_ids, sent_ids,
                attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None):
        """Score every token as span start / span end and compute the loss.

        ``start_pos`` and ``end_pos`` are multiplied element-wise with the
        flattened log-probabilities, so they act as per-token target weights
        (assumed to be 0/1 masks of shape (batch, seq) -- TODO confirm against
        the cache builder). ``label_ids``, ``doc_ids`` and ``sent_ids`` are
        accepted but not used by this method. The remaining keyword arguments
        are forwarded to the BERT encoder unchanged.

        Returns:
            A tuple ``(total_loss, probs_start, probs_end)``. ``total_loss`` is
            the mean of the start and end losses. The two probability tensors
            are detached CPU tensors normalised over the last dimension of the
            per-token scores, i.e. per sequence -- unlike the loss, which
            normalises over the whole flattened batch.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = encoder_outputs[0]

        # Keep channel 1 of each two-way head as the per-token score.
        start_logits = self.start_outputs(hidden_states)[..., 1].contiguous()
        end_logits = self.end_outputs(hidden_states)[..., 1].contiguous()

        start_loss = self._span_loss(start_logits, start_pos)
        end_loss = self._span_loss(end_logits, end_pos)
        total_loss = (start_loss + end_loss)/2

        return total_loss, self._to_probs(start_logits), self._to_probs(end_logits)


## Evaluate

In [None]:
def evaluate(args):
    """Run a fine-tuned PositionClassifier over cached features and dump predictions.

    For every cache file matched by ``args.data_dir`` (comma-separated glob
    patterns) the pickled features are loaded, pushed through the model loaded
    from ``args.checkpoint``, and one text file is written to
    ``<args.output_dir>/<test_mode>_pred/<cache file stem>.txt`` where
    ``test_mode`` is the second-to-last ``/``-separated component of
    ``args.data_dir``. The file holds seven tab-separated rows: ``tokens``
    (input ids), ``preds``, ``doc_ids``, ``sent_ids``, ``labels``,
    ``out_logits_start`` and ``out_logits_end`` (the two per-token values the
    model returns). Padding positions and position 0 are left out.

    Args:
        args: attribute-style configuration. Reads ``server_ip``,
            ``server_port``, ``local_rank``, ``no_cuda``, ``fp16``,
            ``gradient_accumulation_steps``, ``eval_batch_size``, ``data_dir``,
            ``checkpoint`` and ``output_dir``. ``eval_batch_size`` is
            overwritten with its value divided by
            ``gradient_accumulation_steps``.

    Raises:
        ValueError: if ``args.gradient_accumulation_steps`` is smaller than 1.
    """
    ############################################################################
    ##              Multi-GPUs, Distributed settings
    ############################################################################
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    print("n_gpu: ", n_gpu)

    ############################################################################
    ##              batch size, gradient_accumulation_steps
    ############################################################################
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))
    args.eval_batch_size = args.eval_batch_size // args.gradient_accumulation_steps
    print("args.eval_batch_size: ", args.eval_batch_size)

    ###############################################################
    #           Functions for Dataload (load cache data)
    ###############################################################
    def get_file_arrays(data_path):
        """Expand the comma-separated glob patterns into one sorted file list."""
        patterns = data_path.split(",")
        print("data_path: ", patterns)

        cache_files_all = []
        for pattern in patterns:
            cache_files_all.extend(glob.glob(pattern.strip()))
        cache_files_all.sort()

        print("Got " + str(len(cache_files_all)) + " cache_files")

        return cache_files_all

    def to_batch_tensor(values):
        """Build a long tensor and drop singleton dimensions except the batch one."""
        tensor = torch.tensor(values, dtype=torch.long)
        if tensor.size()[0] > 1:
            return torch.squeeze(tensor)
        if tensor.size()[0] == 1 and tensor.dim() > 2:
            # A lone example stored as (1, 1, seq) must become (1, seq): the
            # model flattens its targets with size(0) * size(1) and therefore
            # needs 2-D inputs. A bare squeeze() would also drop the batch
            # dimension, so it is restored afterwards.
            return torch.squeeze(tensor).unsqueeze(0)
        return tensor

    def read_file(cache_file):
        """Load one pickled feature list and wrap it in a DataLoader."""
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # point data_dir at cache files produced by a trusted pipeline.
        with open(cache_file, 'rb') as handle:
            eval_features = list(pickle.load(handle))

        num_eval_examples = len(eval_features)

        eval_data = TensorDataset(
            to_batch_tensor([f.input_ids for f in eval_features]),
            to_batch_tensor([f.input_mask for f in eval_features]),
            to_batch_tensor([f.segment_ids for f in eval_features]),
            to_batch_tensor([f.label_ids for f in eval_features]),
            to_batch_tensor([f.start_pos for f in eval_features]),
            to_batch_tensor([f.end_pos for f in eval_features]),
            to_batch_tensor([f.doc_ids for f in eval_features]),
            to_batch_tensor([f.sent_ids for f in eval_features]),
        )

        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            # Imported here because the file-level import is commented out.
            from torch.utils.data.distributed import DistributedSampler
            eval_sampler = DistributedSampler(eval_data)

        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        return eval_dataloader, num_eval_examples

    ############################################################################
    ##                    Load evaluation data
    ############################################################################
    label_list = ["0", "1"]
    num_labels = len(label_list)

    cache_files = get_file_arrays(args.data_dir)
    print("cache_files: ", cache_files)

    ############################################################################
    ##                      Prepare model
    ############################################################################
    print("loading weights from checkpoint (", args.checkpoint, ")")
    print("args.checkpoint: ", args.checkpoint)
    config = BertConfig.from_pretrained(args.checkpoint, num_labels=num_labels)
    model = PositionClassifier.from_pretrained(args.checkpoint,
              from_tf = False,
              config = config)
    print("loaded weights from checkpoint (", args.checkpoint, ")")

    # NOTE(review): only rank 0 reaches this barrier and no matching barrier is
    # issued by the other ranks before loading -- confirm before using the
    # distributed path.
    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
    print("device: ", device)

    model.to(device)

    ############################################################################
    ##              Start evaluation
    ############################################################################
    model.eval()

    for cache_file in cache_files:
        filename = cache_file.split("/")[-1].split(".")[0]
        eval_dataloader, _ = read_file(cache_file)

        out_doc_ids = []
        out_sent_ids = []
        out_tokens = []
        out_preds = []
        out_labels = []
        out_logits_start = []
        out_logits_end = []

        num_batches = 0
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Iteration"):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, start_pos, end_pos, doc_ids, sent_ids = batch

                _, start_probs, end_probs = model(
                    input_ids=input_ids, label_ids=label_ids,
                    start_pos=start_pos, end_pos=end_pos,
                    doc_ids=doc_ids, sent_ids=sent_ids,
                    attention_mask=input_mask, token_type_ids=segment_ids,
                    position_ids=None, head_mask=None, inputs_embeds=None,
                    output_attentions=None, output_hidden_states=None,
                    return_dict=None)
                num_batches += 1

                start_probs = start_probs.detach().cpu()
                end_probs = end_probs.detach().cpu()

                # Convert once per batch instead of calling .item() per token.
                token_rows = input_ids.tolist()
                mask_rows = input_mask.tolist()
                doc_rows = doc_ids.tolist()
                sent_rows = sent_ids.tolist()
                label_rows = label_ids.tolist()
                start_rows = start_probs.tolist()
                end_rows = end_probs.tolist()

                for b in range(len(token_rows)):
                    # Predicted span: highest-scoring start and end position.
                    preds_start = torch.argmax(start_probs[b], dim=-1).item()
                    preds_end = torch.argmax(end_probs[b], dim=-1).item()

                    for s in range(len(token_rows[b])):
                        # Stop at the first padding position.
                        if mask_rows[b][s] == 0:
                            break
                        # Position 0 (the CLS slot) is excluded from the output.
                        if s == 0:
                            continue

                        asmt_pred = 1 if preds_start <= s <= preds_end else 0

                        out_tokens.append(str(token_rows[b][s]))
                        out_doc_ids.append(str(doc_rows[b][s]))
                        out_sent_ids.append(str(sent_rows[b][s]))
                        out_labels.append(str(label_rows[b][s]))
                        out_preds.append(str(asmt_pred))
                        out_logits_start.append(str(start_rows[b][s]))
                        out_logits_end.append(str(end_rows[b][s]))

        if num_batches == 0:
            # Nothing was predicted for this cache file, so no file is written.
            logger.warning("No batches produced for %s; skipping output", cache_file)
            continue

        outtext = [
            "tokens\t"+"\t".join(out_tokens),
            "preds\t"+"\t".join(out_preds),
            "doc_ids\t"+"\t".join(out_doc_ids),
            "sent_ids\t"+"\t".join(out_sent_ids),
            "labels\t"+"\t".join(out_labels),
            "out_logits_start\t"+"\t".join(out_logits_start),
            "out_logits_end\t"+"\t".join(out_logits_end),
        ]

        test_mode = args.data_dir.split("/")[-2]
        pred_dir = args.output_dir+"/"+str(test_mode)+"_pred"
        os.makedirs(pred_dir, exist_ok=True)

        # Written once per cache file, after all of its batches are collected.
        with open(pred_dir+"/"+str(filename)+".txt", "w", encoding="utf-8") as out_file:
            out_file.write("\n".join(outtext))

    

In [None]:
def main(args):
    """Validate the run configuration, seed the RNGs and dispatch train/eval.

    Args:
        args: attribute-style configuration. Reads ``output_dir``,
            ``do_train``, ``do_eval`` and ``seed``; the whole object is passed
            on to ``train`` and/or ``evaluate``.

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set, or if
            training is requested while ``output_dir`` already exists and is
            not empty.
    """
    print("args.output_dir: ", args.output_dir)

    # check directories
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(args.output_dir, exist_ok=True)

    # Seed every random number generator used in this file.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.do_train:
        # NOTE(review): `train` is not defined in the visible part of this
        # file; it must be provided elsewhere before do_train can be used.
        train(args)

    if args.do_eval:
        evaluate(args)

# Inference

In [None]:
# Cache-file glob patterns, one per model; the third path component names the model.
data_dirs = [
    './cache/bertbase_cased/test/*.cache',
    './cache/mbert_cased/test/*.cache',
    './cache/biobert/test/*.cache',
    './cache/kobert/test/*.cache',
]

# Fine-tuned checkpoints, paired by position with data_dirs.
checkpoints = [
    "./finetuned/ver9.1.4_521121_epoch2/4597",
    "./finetuned/ver8.1.4_1142642_epoch2/4597",
    './finetuned/ver11.1.4_521079_epoch2/4597',
    './finetuned/ver12.1.4_407013_epoch2/4597',
]

assert len(data_dirs)==len(checkpoints)

# Per-model settings copied onto args, keyed by the model directory name.
_MODEL_SETTINGS = {
    "mbert_cased": {
        "attention_probs_dropout_prob": 0.1,
        "directionality": "bidi",
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "max_position_embeddings": 512,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "pooler_fc_size": 768,
        "pooler_num_attention_heads": 12,
        "pooler_num_fc_layers": 3,
        "pooler_size_per_head": 128,
        "pooler_type": "first_token_transform",
        "type_vocab_size": 2,
        "vocab_size": 119547,
        "cls_id": 101,
        "sep_id": 102,
    },
    "mbert_uncased": {
        "attention_probs_dropout_prob": 0.1,
        "directionality": "bidi",
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "layer_norm_eps": 1e-12,
        "max_position_embeddings": 512,
        "model_type": "bert",
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "pad_token_id": 0,
        "pooler_fc_size": 768,
        "pooler_num_attention_heads": 12,
        "pooler_num_fc_layers": 3,
        "pooler_size_per_head": 128,
        "pooler_type": "first_token_transform",
        "type_vocab_size": 2,
        "vocab_size": 105879,
        "cls_id": 101,
        "sep_id": 102,
    },
    "bertbase_uncased": {
        "vocab_size": 30522,  # bert-base-uncased
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "hidden_act": 'gelu',
        "intermediate_size": 3072,
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-12,
        "gradient_checkpointing": None,
        "position_embedding_type": None,
        "use_cache": None,
        "classifier_dropout": None,
        "cls_id": 101,
        "sep_id": 102,
    },
    "bertbase_cased": {
        "vocab_size": 28996,  # bert-base-cased
        "attention_probs_dropout_prob": 0.1,
        "hidden_act": 'gelu',
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "layer_norm_eps": 1e-12,
        "max_position_embeddings": 512,
        "model_type": "bert",
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "pad_token_id": 0,
        "type_vocab_size": 2,
        "gradient_checkpointing": None,
        "position_embedding_type": None,
        "use_cache": None,
        "classifier_dropout": None,
        "cls_id": 101,
        "sep_id": 102,
    },
    "biobert": {
        "attention_probs_dropout_prob": 0.1,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "max_position_embeddings": 512,
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "type_vocab_size": 2,
        "vocab_size": 28996,
        "cls_id": 101,
        "sep_id": 102,
    },
    "kobert": {
        "attention_probs_dropout_prob": 0.1,
        "gradient_checkpointing": False,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "hidden_size": 768,
        "initializer_range": 0.02,
        "intermediate_size": 3072,
        "layer_norm_eps": 1e-12,
        "max_position_embeddings": 512,
        "model_type": "bert",
        "num_attention_heads": 12,
        "num_hidden_layers": 12,
        "pad_token_id": 1,
        "type_vocab_size": 2,
        "vocab_size": 8002,
        "author": "Heewon Jeon(madjakarta@gmail.com)",
        "kobert_version": 1.0,
        "cls_id": 2,
        "sep_id": 3,
    },
}

# Settings used when the directory name matches none of the known models.
_FALLBACK_SETTINGS = {
    "vocab_size": 30014,
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "hidden_act": 'gelu',
    "intermediate_size": 3072,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-12,
    "gradient_checkpointing": None,
    "position_embedding_type": None,
    "use_cache": None,
    "classifier_dropout": None,
    "cls_id": 4,
    "sep_id": 5,
}

# Message printed when a model's settings are selected (the bertbase_* models print nothing).
_MODEL_BANNERS = {
    "mbert_cased": "multilingualbert configs",
    "mbert_uncased": "multilingualbert configs",
    "biobert": "biobert configs",
    "kobert": "kobert configs",
}

for d, (data_dir, checkpoint) in enumerate(zip(data_dirs, checkpoints)):
    # Predictions are written next to the checkpoint's parent directory name.
    output_dir = "./finetuned/"+str(checkpoint.split("/")[-2])
    print("output_dir: ", output_dir)

    args = easydict.EasyDict({
        'data_dir':data_dir,
        'eval_batch_size':5,             # train_batch_size/gradient_accumulation_steps
        'gradient_accumulation_steps':1,  # become update step
        'save_globalstep':1,
        'checkpoint':checkpoint,
        'output_dir':output_dir,
        'save_step':1,
        'num_files_on_memory':1,

        'do_train':False,
        'do_eval':True,
        'do_lower_case':True,
        'fp16_allreduce':False,
        'seed':42,
        'no_cuda':False,
        'use_adasum':False,
        'learning_rate':3e-05,
        'warmup_proportion':1.00,
        'num_train_epochs':2,
        'save_file_limit':1,

        'weight_decay':0.01,
        'adam_epsilon':1e-8,
        'gradient_predivide_factor':1.0,
        'fp16':False,
        'max_grad_norm':1.0,
        'local_rank':-1,
        'fp16_opt_level':'O1',
        'loss_scale':0,

        'server_ip':None,
        'server_port':None,
        'cuda':True,
    })

    # The model name is the third "/"-separated component of the glob pattern.
    model_key = str(data_dir.split("/")[2])
    if model_key in _MODEL_SETTINGS:
        if model_key in _MODEL_BANNERS:
            print(_MODEL_BANNERS[model_key])
        model_settings = _MODEL_SETTINGS[model_key]
    else:
        print("MY BERT")
        model_settings = _FALLBACK_SETTINGS

    for setting_name, setting_value in model_settings.items():
        args[setting_name] = setting_value

    main(args)