In [13]:
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import pickle
import torch
from gluonnlp.data import SentencepieceTokenizer
from model.net import KobertCRF
from data_utils.utils import Config
from data_utils.vocab_tokenizer import Tokenizer
from data_utils.pad_sequence import keras_pad_fn
from pathlib import Path

In [2]:
class DecoderFromNamedEntitySequence:
    """Decode predicted NER tag ids back into entities for the first sequence of a batch.

    Two views of the same prediction are produced:

    * ``list_of_ner_word``: one ``{"word", "tag", "prob"}`` dict per entity span.
    * ``decoding_ner_sentence``: the input text with each entity wrapped as ``<word:TAG>``.

    SentencePiece word-boundary markers (``'▁'``) are turned into spaces in both outputs.
    """

    def __init__(self, tokenizer, index_to_ner):
        """
        Args:
            tokenizer: object with ``decode_token_ids(list_of_input_ids)`` returning one
                list of token strings per input sequence.
            index_to_ner: mapping from predicted tag id to tag string such as ``"B-PER"``
                or ``"I-PER"``. The entity type is taken as the last 3 characters of the tag.
        """
        self.tokenizer = tokenizer
        self.index_to_ner = index_to_ner

    def __call__(self, list_of_input_ids, list_of_pred_ids):
        """Decode the first sequence of the batch.

        Args:
            list_of_input_ids: batch of input token ids; only element 0 is decoded.
            list_of_pred_ids: batch of predicted tag ids aligned with the input tokens;
                only element 0 is used.

        Returns:
            ``(list_of_ner_word, decoding_ner_sentence)``: a list of
            ``{"word": str, "tag": str, "prob": None}`` dicts, and the decoded text
            (first and last positions, i.e. [CLS]/[SEP], removed) with entities
            rendered as ``<word:TAG>``.
        """
        input_token = self.tokenizer.decode_token_ids(list_of_input_ids)[0]
        pred_ner_tag = [self.index_to_ner[pred_id] for pred_id in list_of_pred_ids[0]]

        list_of_ner_word = self._parse_ner_words(input_token, pred_ner_tag)
        decoding_ner_sentence = self._build_tagged_sentence(input_token, pred_ner_tag)
        return list_of_ner_word, decoding_ner_sentence

    @staticmethod
    def _parse_ner_words(input_token, pred_ner_tag):
        """Group B-/I- tagged tokens into entity records.

        A ``B-`` tag always starts a new entity and closes any open one, even one of the
        same type. A matching ``I-<TAG>`` extends the open entity, and any other tag
        closes it. An entity still open at the end of the sequence is emitted too.
        """
        list_of_ner_word = []
        entity_word, entity_tag = "", ""

        def close_entity():
            # Emit the currently open entity, if any. I- tokens with no preceding B-
            # accumulate without a tag and are therefore not emitted.
            if entity_word and entity_tag:
                list_of_ner_word.append({"word": entity_word.replace("▁", " "), "tag": entity_tag, "prob": None})

        for token, pred_ner_tag_str in zip(input_token, pred_ner_tag):
            if "B-" in pred_ner_tag_str:
                close_entity()
                entity_word, entity_tag = token, pred_ner_tag_str[-3:]
            elif "I-" + entity_tag in pred_ner_tag_str:
                entity_word += token
            else:
                close_entity()
                entity_word, entity_tag = "", ""
        close_entity()
        return list_of_ner_word

    @staticmethod
    def _build_tagged_sentence(input_token, pred_ner_tag):
        """Rebuild the text with each entity wrapped as ``<word:TAG>``.

        The first and last positions ([CLS] / [SEP]) are skipped. The tag of an entity
        is taken from its B- token. An I- token with no open entity is emitted as plain
        text. An entity that is still open when the text ends is closed explicitly.
        """
        pieces = []
        is_prev_entity = False
        prev_entity_tag = ""
        last_index = len(pred_ner_tag) - 1

        for i, (token_str, pred_ner_tag_str) in enumerate(zip(input_token, pred_ner_tag)):
            if i == 0 or i == last_index:  # remove [CLS], [SEP]
                continue
            token_str = token_str.replace('▁', ' ')  # SentencePiece '▁' marks a preceding space

            if 'B-' in pred_ner_tag_str:
                if is_prev_entity:
                    pieces.append(':' + prev_entity_tag + '>')
                # Keep the word-separating space outside the '<' so the entity text is clean.
                if token_str.startswith(' '):
                    pieces.append(' <' + token_str[1:])
                else:
                    pieces.append('<' + token_str)
                is_prev_entity = True
                prev_entity_tag = pred_ner_tag_str[-3:]
            elif 'I-' in pred_ner_tag_str:
                pieces.append(token_str)
            else:
                if is_prev_entity:
                    pieces.append(':' + prev_entity_tag + '>')
                    is_prev_entity = False
                pieces.append(token_str)

        # Close an entity that runs up to the end of the text.
        if is_prev_entity:
            pieces.append(':' + prev_entity_tag + '>')
        return ''.join(pieces)

## Namecard Test

In [3]:
# Input: OCR annotation JSON for the validation split, read by the loading cell below.
json_path = str(Path('/opt/ml/ocr') / 'info_val.json')

In [4]:
# Load the OCR annotation file (`json` is imported in the top imports cell).
# JSON is UTF-8 by spec and the OCR text may be non-ASCII (e.g. Korean), so decode
# explicitly instead of relying on the platform's default locale encoding.
with open(json_path, 'r', encoding='utf-8') as f:
    json_data = json.load(f)

anns = json_data['annotations']

In [5]:
# One input string per annotation: every OCR word's text followed by a space,
# with a trailing '.' appended.
all_words_list = [
    ''.join(bbox['text'] + ' ' for bbox in ann['ocr']['word']) + '.'
    for ann in anns
]

In [6]:
def namecard_test(texts=None,
                  model_dir='./experiments/base_model_with_crf',
                  tok_path="./ptr_lm_model/tokenizer_78b3253a26.model",
                  checkpoint_path='/opt/ml/ocr/ner/ner_kobert/experiments/base_model_with_crf_val/best-epoch-12-step-1000-acc-0.961.bin'):
    """Run the KoBERT-CRF NER model over OCR'd name-card texts and group entities by tag.

    Args:
        texts: iterable of input strings; defaults to the notebook-level
            ``all_words_list`` built from the OCR annotations.
        model_dir: directory holding ``config.json``, ``vocab.pkl`` and
            ``ner_to_index.json``.
        tok_path: SentencePiece model file for the pretrained tokenizer.
        checkpoint_path: fine-tuned weights file; must contain a
            ``'model_state_dict'`` entry.

    Returns:
        dict mapping each of PER, LOC, ORG, POH, DAT, TIM, DUR, MNY, PNT, NOH to the
        list of entity strings found for that tag, in input order.
    """
    if texts is None:
        texts = all_words_list
    model_dir = Path(model_dir)
    model_config = Config(json_path=model_dir / 'config.json')

    # load vocab & tokenizer
    ptr_tokenizer = SentencepieceTokenizer(tok_path)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted vocab files.
    with open(model_dir / "vocab.pkl", 'rb') as f:
        vocab = pickle.load(f)
    tokenizer = Tokenizer(vocab=vocab, split_fn=ptr_tokenizer, pad_fn=keras_pad_fn, maxlen=model_config.maxlen)

    # load ner_to_index.json
    with open(model_dir / "ner_to_index.json", 'rb') as f:
        ner_to_index = json.load(f)
    index_to_ner = {v: k for k, v in ner_to_index.items()}

    model = KobertCRF(config=model_config, num_classes=len(ner_to_index), vocab=vocab)

    # Strip the "module." prefix (presumably added by a DataParallel wrapper at training
    # time) and skip checkpoint keys the current model does not define.
    model_dict = model.state_dict()
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    convert_keys = {}
    for k, v in checkpoint['model_state_dict'].items():
        new_key_name = k.replace("module.", '')
        if new_key_name not in model_dict:
            print("{} is not in model_dict".format(new_key_name))
            continue
        convert_keys[new_key_name] = v

    model.load_state_dict(convert_keys)
    model.eval()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    decoder_from_res = DecoderFromNamedEntitySequence(tokenizer=tokenizer, index_to_ner=index_to_ner)

    tag_result = {tag: [] for tag in ('PER', 'LOC', 'ORG', 'POH', 'DAT', 'TIM', 'DUR', 'MNY', 'PNT', 'NOH')}

    # Inference only: no_grad avoids building an autograd graph for every input.
    with torch.no_grad():
        for input_text in texts:
            list_of_input_ids = tokenizer.list_of_string_to_list_of_cls_sep_token_ids([input_text])
            x_input = torch.tensor(list_of_input_ids).long().to(device)
            list_of_pred_ids = model(x_input)

            _, decoding_ner_sentence = decoder_from_res(list_of_input_ids=list_of_input_ids, list_of_pred_ids=list_of_pred_ids)
            print("output>", decoding_ner_sentence)
            print("--------------------------------------------------------------------------------------------------------")
            print("")

            # Entities are rendered as "<word:TAG>". Split on the LAST ':' so that words
            # which themselves contain ':' (e.g. times such as "10:30") stay whole.
            for chunk in decoding_ner_sentence.replace('>', '<').split('<'):
                if ':' in chunk:
                    word, tag = chunk.rsplit(':', 1)
                    if tag in tag_result:
                        tag_result[tag].append(word)

    return tag_result



### NER Tagset


- 10개의 태그  
  - PER: 사람이름
  - LOC: 지명
  - ORG: 기관명
  - POH: 기타
  - DAT: 날짜
  - TIM: 시간
  - DUR: 기간
  - MNY: 통화
  - PNT: 비율
  - NOH: 기타 수량표현
- 개체의 범주
  - 개체이름: 사람이름(PER), 지명(LOC), 기관명(ORG), 기타(POH)
  - 시간표현: 날짜(DAT), 시간(TIM), 기간 (DUR)
  - 수량표현: 통화(MNY), 비율(PNT), 기타 수량표현(NOH)
# Duplicate `result = namecard_test()` call removed: inference runs once in the In [10] cell below.

In [10]:
# Run NER over every OCR'd card text (`all_words_list`); `result` maps tag -> list of entity strings.
result = namecard_test()

In [11]:
# Number of entities collected for each tag.
print("Num of Tags")
for tag, entity_words in result.items():
    print(f"{tag} : {len(entity_words)}")

In [12]:
# Show up to MAX_EXAMPLES collected entities per tag (slice keeps the output readable).
MAX_EXAMPLES = 50
print('Example')
for tag, entity_words in result.items():
    print(f"{tag} : {entity_words[:MAX_EXAMPLES]}")