In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import pandas as pd
import re

In [3]:
class Evaluator:
    def __init__(self, predictions_df):
        self.actual = list(predictions_df["Actual Text"])
        self.generated = list(predictions_df["Generated Text"].apply(lambda x : x[12:]))

    def _split_string_to_tokens(self, text):
        tokens_split_by_space = text.split()
        tokens_list = []
        for idx, token in enumerate(tokens_split_by_space):
            multi_token = token.replace('이고,',' ').replace('이다.',' ').replace('고,',' ').replace('다.',' ').split()
            for t in multi_token:
                #예외 처리 : extra id tag를 가진 경우
                if ('<extra_id' in t) :
                    id_idx = token.rfind('>')
                    t_except_id = token[id_idx+1:]
                    if len(t_except_id) > 0 :
                        tokens_list.append(t_except_id)

                #split후 비어있는 string인 경우
                elif (len(t)>0):
                    tokens_list.append(t)  
            
        return tokens_list

    
    def _make_entities_and_tags_list(self, text):
        kor_to_eng = {'사람' : 'PS', '위치' : 'LC', '기관' : 'OG', '날짜' : 'DT' ,'시간' : 'TI', '수량' : 'QT'}
        entities, tags = [], []
        tokens_list = self._split_string_to_tokens(text)
        entity = ''
        for idx, token in enumerate(tokens_list):
            #tag
            if (token in kor_to_eng.keys()):
                tags.append(kor_to_eng[token[:2]])
                entities.append(entity[:-2])
                entity = ''
            else:
                entity += token + ' '
        return entities, tags

    def _make_a_answer_dataframe(self,text_list):
        entities_ , tags_ = [], []
        for text in text_list:
            entities, tags = self._make_entities_and_tags_list(text)
            entities_.append(entities)
            tags_.append(tags)

        df = pd.DataFrame({'entity': entities_, 'tag' : tags_})
        return df

    def _check_correct_answer(self,true_entity, true_tag, pred_entities, pred_tags):
        check,result_entity,result_tag = False,'',''
        for idx, pred_entity in enumerate(pred_entities):
            if (true_entity == pred_entity) and (pred_tags[idx] == true_tag):
                return True, pred_entity, pred_tags[idx]
            elif (true_entity in pred_entity) and (pred_tags[idx] == true_tag):
                return True, pred_entity, pred_tags[idx]
            elif (pred_entity in true_entity) and (pred_tags[idx] == true_tag):
                return True, pred_entity, pred_tags[idx]
        return check, result_entity, result_tag

    def evaluate(self):
        y_pred = self._make_a_answer_dataframe(self.generated)
        y_true = self._make_a_answer_dataframe(self.actual)

        TP,true_total,pred_total = 0,0,0

        for idx, pred in y_pred.iterrows(): 
            pred_entities, pred_tags = pred["entity"], pred["tag"]
            true_entities, true_tags = y_true["entity"][idx], y_true["tag"][idx]

            true_total += len(true_entities)
            pred_total += len(pred_entities)

            #주의 : 추론 리스트에서 존재하는 것으로 확인된 것은 remove (substring : 동일한 NE가 2개 이상인 경우를 고려)
            for e_t, t_t in zip(true_entities,true_tags): #정답리스트의 NE가 추론리스트에 존재하는지
                result, e_p, t_p = self._check_correct_answer(e_t,t_t, pred_entities, pred_tags)
                if result == True:
                    pred_entities.remove(e_p)
                    pred_tags.remove(t_p)
                    TP += 1
      
        precision = TP / pred_total
        recall = TP / true_total
        f1_score = 2 / (1/precision + 1/recall)

        print("prediction 개수 : ", pred_total)
        print("ground truth 개수 : ",true_total)
        print("precision : ",precision)
        print("recall : ", recall)
        print("f1_score : ",f1_score)

        return precision, recall, f1_score,y_pred,y_true

In [4]:
predictions_df = pd.read_csv('./output/submission.csv')
e = Evaluator(predictions_df)
precision, recall, f1_score,y_pred,y_true = e.evaluate()

prediction 개수 :  12450
ground truth 개수 :  13075
precision :  0.9327710843373493
recall :  0.8881835564053537
f1_score :  0.9099314397649363
