# NLI eval
- This code is adapted from the following GitHub repository:
- https://github.com/kamalkraj/BERT-NER

In [None]:
from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import logging
import os
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F

# mult processing
#import torch.multiprocessing as mp

import torch.utils.data.distributed

from transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                          BertForTokenClassification, BertTokenizer,
                          get_linear_schedule_with_warmup, 
                          BertPreTrainedModel, BertModel) # cls classifier 를 위해 로드

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
# from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from seqeval.metrics import classification_report
from seqeval.metrics import sequence_labeling # 따로 결과를 뽑아내고 싶음

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)
import codecs

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle
    
import glob
import time

import easydict

import six
import tqdm
from tqdm import tqdm, trange
import collections

from seqeval.metrics import classification_report

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

import re

import shutil

In [None]:
class TrainingInstance:
    """Container for one pre-tokenised example (a sentence pair).

    It only stores the five values handed to the constructor under
    attributes of the same names; nothing is validated or converted.
    """

    def __init__(self, input_ids, input_mask, segment_ids, labels, eval_pos):
        # Encoder inputs.
        self.input_ids, self.input_mask, self.segment_ids = (
            input_ids,
            input_mask,
            segment_ids,
        )
        # Targets and the positions at which they are evaluated.
        self.labels, self.eval_pos = labels, eval_pos

## SO classifier

In [None]:
class SO_classifier(BertPreTrainedModel):
    """BERT encoder with a two-way linear head over (AP, SO) position pairs.

    For every row of the batch, ``eval_pos[b][0]`` is taken as the AP position
    and every following non-zero entry as an SO position.  The encoder hidden
    states at the two positions are concatenated (``2 * hidden_size``) and
    classified into two classes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size*2, 2)
        self.init_weights()
        # Not used by forward(); kept so the attributes remain available.
        self.loss = torch.nn.BCELoss(reduction='none')
        self.actfct = torch.nn.Sigmoid()

    def forward(self, input_ids, label_ids, eval_pos,  # 220513 eval_pos
                attention_mask=None, token_type_ids=None,  # 220315 segment_id
                position_ids=None, head_mask=None, inputs_embeds=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None, split_num=1):
        """Score every (AP, SO) pair in the batch.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states,
            return_dict: forwarded unchanged to the BERT encoder.
            label_ids: per-row labels; ``label_ids[b][k]`` is the label of the
                k-th SO position of row ``b``.
            eval_pos: per-row positions; index 0 is the AP position, the
                following entries are SO positions, and a 0 ends the list.
            split_num: accepted for interface compatibility; not used here.

        Returns:
            ``(loss, outtext)``.  ``loss`` is the summed negative
            log-likelihood divided by ``number_of_pairs + 1e-5``.  ``outtext``
            holds one tab-separated string per pair:
            ``batch_<row>``, the log-softmax scores, the gold label, the
            predicted label.  When the batch contains no pair at all, a zero
            loss and an empty list are returned.
        """
        outputs = self.bert(
            input_ids = input_ids,
            attention_mask = attention_mask,
            token_type_ids = token_type_ids,
            position_ids = position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Original note: the first [SEP] carries the AP context, the later
        # [SEP] tokens carry the SO contexts.
        sequence_output = outputs[0]

        pair_vectors = []    # one concatenated (AP, SO) vector per pair
        pair_labels = []     # gold label per pair, same order as pair_vectors
        pairs_per_row = []   # number of pairs contributed by each batch row

        for b in range(len(eval_pos)):
            ap_point = eval_pos[b][0]
            n_pairs = 0
            for e in range(1, len(eval_pos[b])):
                so_point = eval_pos[b][e]
                if so_point == 0:
                    # A zero position marks the end of the SO positions.
                    break
                pair_vectors.append(
                    torch.cat([sequence_output[b][ap_point],
                               sequence_output[b][so_point]])
                )
                pair_labels.append(label_ids[b][e-1].item())
                n_pairs += 1
            pairs_per_row.append(n_pairs)

        if not pair_vectors:
            # torch.stack() raises on an empty list; return a zero loss that
            # is still connected to the graph, and no prediction rows.
            return sequence_output.sum() * 0.0, []

        valid_output = torch.stack(pair_vectors)
        seq_relationship_scores = self.classifier(valid_output)

        # One-hot targets, built directly on the scores' device.
        target = torch.tensor(pair_labels, dtype=torch.long,
                              device=seq_relationship_scores.device)
        label_hot = torch.zeros_like(seq_relationship_scores)
        label_hot[torch.arange(len(pair_labels), device=target.device), target] = 1

        # Cross entropy: mean negative log-likelihood of the gold class.
        prediction_scores = nn.functional.log_softmax(seq_relationship_scores, dim=-1)
        numerator = (-1)* torch.sum(prediction_scores * label_hot)
        denominator = torch.sum(label_hot) + 1e-5
        loss = numerator / denominator

        # NOTE: these are log-probabilities (log_softmax output), not
        # probabilities; argmax is unaffected.
        probs = prediction_scores.detach().cpu()
        preds = torch.argmax(probs, dim=-1).tolist()

        # Scores were computed over the flat list of pairs; walk the per-row
        # counts to recover which batch row each pair came from.
        outtext = []
        flat_idx = 0
        for b, n_pairs in enumerate(pairs_per_row):
            for _ in range(n_pairs):
                outtext.append(
                    "batch_"+str(b)+"\t"+str(probs[flat_idx])+"\t"
                    +str(pair_labels[flat_idx])+"\t"+str(int(preds[flat_idx]))
                )
                flat_idx += 1

        return loss, outtext

## Evaluation

In [None]:
def evaluate(args):
    """Run a fine-tuned ``SO_classifier`` over cached test files and score it.

    Reads from ``args``: ``local_rank``, ``no_cuda``, ``fp16`` (logging only),
    ``data_dir`` (comma-separated glob patterns), ``eval_batch_size``,
    ``checkpoint``, ``output_dir`` and ``split_num``.

    For every matched cache file a ``<output_dir>/preds/<name>.txt`` file with
    one line per evaluated pair is written.  Precision, recall and F1 of
    class 0, accumulated over all files, are written to
    ``<output_dir>/sentext_f1_score.txt``.  A metric whose denominator is zero
    is reported as 0.0.
    """
    ############################################################################
    ##              Multi-GPUs, Distributed settings
    ############################################################################
    if args.local_rank == -1 or args.no_cuda:
        # Fall back to CPU when CUDA is disabled or not available.
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device("cuda:0" if use_cuda else "cpu")
        n_gpu = 1 if use_cuda else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    ###############################################################
    #              Functions for Dataload (load cache data)
    ###############################################################
    def get_file_arrays(data_path):
        """Expand comma-separated glob patterns into one sorted file list."""
        data_path = data_path.split(",")
        print("data_path: ", data_path)

        cache_files_all = []
        for pattern in data_path:
            cache_files_all.extend(glob.glob(pattern.strip()))
        cache_files_all.sort()

        print("Got " + str(len(cache_files_all)) + " cache_files")
        return cache_files_all

    def split_file_array(cache_files, num_files_on_memory):
        """Chunk the file list into groups of at most ``num_files_on_memory``.

        Currently not called by the evaluation loop below.
        """
        return [cache_files[i:i+num_files_on_memory]
                for i in range(0, len(cache_files), num_files_on_memory)]

    def to_batch_tensor(values):
        """Build a long tensor and drop singleton non-batch dimensions.

        A blanket ``torch.squeeze`` would also remove the batch dimension when
        a file holds a single example, so dimension 0 is always kept and the
        result keeps at least two dimensions.
        """
        tensor = torch.tensor(values, dtype=torch.long)
        for dim in range(tensor.dim() - 1, 0, -1):
            if tensor.dim() > 2 and tensor.size(dim) == 1:
                tensor = tensor.squeeze(dim)
        return tensor

    def read_files(cache_files_pieces):
        """Load pickled features and wrap them in a DataLoader."""
        eval_features = []
        for cache_path in cache_files_pieces:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load cache files produced by a trusted pipeline.
            with open(cache_path, 'rb') as cache_fh:
                eval_features.extend(pickle.load(cache_fh))
            print("cache_files_pieces[c]: ", cache_path)
        print("len(eval_features): ", len(eval_features))

        # Number of evaluation samples.
        num_eval_examples = len(eval_features)

        # Extra singleton dimensions must be removed, otherwise the model's
        # forward pass fails.
        all_input_ids = to_batch_tensor([f.input_ids for f in eval_features])
        all_input_mask = to_batch_tensor([f.input_mask for f in eval_features])
        all_segment_ids = to_batch_tensor([f.segment_ids for f in eval_features])
        all_label_ids = to_batch_tensor([f.labels for f in eval_features])
        all_eval_pos = to_batch_tensor([f.eval_pos for f in eval_features])

        eval_data = TensorDataset(
            all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_eval_pos
        )

        if args.local_rank == -1:
            # Evaluation does not need shuffling; sequential order keeps the
            # prediction file aligned with the cached examples.
            eval_sampler = SequentialSampler(eval_data)
        else:
            # Fully qualified: the `from ... import DistributedSampler` line
            # at the top of the file is commented out.
            eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)

        print("len(eval_data): ", len(eval_data))
        print("args.eval_batch_size: ", args.eval_batch_size)

        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        print("len(eval_dataloader): ", len(eval_dataloader))

        return eval_dataloader, num_eval_examples

    ############################################################################
    ##                    labels
    ############################################################################
    label_list = ["0", "1"] # not extraction, extraction
    num_labels = len(label_list) # no +1 here; NER needed it for [PAD]

    ############################################################################
    ##                      Prepare model
    ############################################################################
    print("loading weights from checkpoint (", args.checkpoint, ")")
    config = BertConfig.from_pretrained(args.checkpoint, num_labels=num_labels)
    model = SO_classifier.from_pretrained(args.checkpoint,
              from_tf = False,
              config = config)
    print("loaded weights from checkpoint (", args.checkpoint, ")")

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
    print("device: ", device)

    model.to(device)

    ############################################################################
    ##              Start evaluation
    ############################################################################
    model.eval()

    cache_files = get_file_arrays(args.data_dir)
    print("cache_files: ", cache_files)

    true_pred = 0       # rows predicted as class 0
    true_gold = 0       # rows whose gold label is class 0
    true_positives = 0  # rows where both are class 0

    preds_dir = os.path.join(args.output_dir, "preds")
    os.makedirs(preds_dir, exist_ok=True)

    for cache_path in cache_files:
        print("targetfile: ", cache_path)
        eval_dataloader, num_eval_examples = read_files([cache_path])

        filename = os.path.basename(cache_path)
        stem = filename[:-len(".cache")] if filename.endswith(".cache") else filename

        # NOTE(review): the header names two columns but each row has four
        # (batch tag, scores, label, prediction); kept as-is for compatibility.
        out_lines = ["label\tprediction"]

        # Inference only: do not build autograd graphs.
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Iteration"):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, eval_pos = batch

                _, rows = model(input_ids=input_ids, label_ids=label_ids, eval_pos=eval_pos,
                    attention_mask=input_mask, token_type_ids=segment_ids,
                    position_ids=None, head_mask=None, inputs_embeds=None,
                    output_attentions=None, output_hidden_states=None,
                    return_dict=None, split_num=args.split_num)

                out_lines.extend(rows)

                for row in rows:
                    fields = row.split("\t")
                    label = int(fields[-2])
                    pred = int(fields[-1])

                    # prediction entailed
                    if pred == 0:
                        true_pred += 1
                    # gold entailed
                    if label == 0:
                        true_gold += 1
                    # true positives
                    if label == pred and label == 0:
                        true_positives += 1

        # output predictions
        with open(os.path.join(preds_dir, stem + ".txt"), "w") as pred_file:
            pred_file.write("\n".join(out_lines) + "\n")

        # entailed (same patient) = 0
        # different patients = 1
        # The counters are cumulative over the files processed so far.
        print("true_positives: ", true_positives)
        print("true_pred: ", true_pred)
        print("true_gold: ", true_gold)

    # Final scores over all files; guard every division against zero.
    precision = float(true_positives/true_pred) if true_pred else 0.0
    recall = float(true_positives/true_gold) if true_gold else 0.0
    if precision > 0 and recall > 0:
        f1_score = 2 / ((1/precision) + (1/recall))
    else:
        f1_score = 0.0

    # output score
    scoretext = "precision\trecall\tf1score\n"
    scoretext = scoretext+str(precision)+"\t"+str(recall)+"\t"+str(f1_score)

    with open(os.path.join(args.output_dir, "sentext_f1_score.txt"), "w") as score_file:
        score_file.write(scoretext)
    
    

In [None]:
def main(args):
    """Validate the run configuration, seed the RNGs and dispatch.

    Args:
        args: attribute-style configuration; reads ``output_dir``,
            ``do_train``, ``do_eval`` and ``seed``.

    Raises:
        ValueError: if neither ``do_train`` nor ``do_eval`` is set, or if
            training is requested while ``output_dir`` already exists and is
            not empty.
    """
    print("args.output_dir: ", args.output_dir)

    # check directories
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # Seed every RNG in use so runs are repeatable.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.do_train:
        # NOTE(review): `train` is not defined in the visible part of this
        # file; it must be provided elsewhere before enabling do_train.
        train(args)

    if args.do_eval:
        evaluate(args)

# Evaluate Test set

In [None]:
# Test-set cache globs, one per model; paired by index with `checkpoints`.
data_dirs = [
    './cache/bertbase_cased/test/*.cache',
    './cache/mbert_uncased/test/*.cache',
    './cache/biobert/test/*.cache',
    './cache/kobert/test/*.cache',
]

# NOTE(review): the second data dir is named "mbert_uncased" while the
# checkpoint comment says "mbert cased" -- confirm which is correct.
checkpoints = [
    './finetuned/ver8.1.4_1142642_epoch2/4045',  # bert cased (pretrained on SNUH visit records)
    './finetuned/ver9.1.4_521121_epoch2/4045',   # mbert cased (pretrained on SNUH visit records)
    './finetuned/ver11.1.4_521079_epoch2/4045',  # biobert (pretrained on SNUH visit records)
    './finetuned/ver12.1.4_407013_epoch2/4045',  # kobert (pretrained on SNUH visit records)
]

assert len(data_dirs)==len(checkpoints)

split_nums=[1]*len(checkpoints)

assert len(split_nums)==len(checkpoints)

# Evaluate every (data, checkpoint) pair; results are written next to the
# checkpoint because output_dir is set to the checkpoint directory.
for data_dir, checkpoint, split_num in zip(data_dirs, checkpoints, split_nums):
    # Each key appears exactly once; where the original literal repeated a
    # key, the value that took effect (the last one) is kept.
    args = easydict.EasyDict({
        'data_dir':data_dir,
        'eval_batch_size':4,             # train_batch_size/gradient_accumulation_steps
        'gradient_accumulation_steps':1,  # become update step
        'save_globalstep':1,
        'checkpoint':checkpoint,
        'output_dir':checkpoint,
        'split_num':split_num,

        'num_files_on_memory':2,

        'num_train_epochs':2,
        'warmup_proportion':0.5,
        'do_train':False,
        'do_eval':True,
        'do_lower_case':True,
        'fp16_allreduce':False,
        'seed':42,
        'no_cuda':False,
        'use_adasum':False,
        'learning_rate':0.00005,
        'weight_decay':0.01,
        'adam_epsilon':1e-8,
        'gradient_predivide_factor':1.0,
        'fp16':False,
        'max_grad_norm':1.0,
        'local_rank':-1,
        'fp16_opt_level':'O1',
        'loss_scale':0,
        'server_ip':None,
        'server_port':None,
        'cuda':True,
        'vocab_size':30014, # ver1.5~ver2.2
        'hidden_size':768,
        'num_hidden_layers':12,
        'num_attention_heads':12,
        'hidden_act':'gelu',
        'intermediate_size':3072,
        'hidden_dropout_prob':0.1,
        'attention_probs_dropout_prob':0.1,
        'max_position_embeddings':512,
        'type_vocab_size':2,
        'initializer_range':0.02,
        'layer_norm_eps':1e-12,
        'gradient_checkpointing':None,
        'position_embedding_type':None,
        'use_cache':None,
        'classifier_dropout':None,

        'max_seq_length':512,
    })

    main(args)