# Modules and Global Variables

In [1]:
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
)

import torch, copy, json, re, os
from tqdm import tqdm
from module.score import evaluation_f1

In [2]:
def jsonload(fname, encoding="utf-8"):
    """Read a single JSON document from a file.

    Args:
        fname: Path of the JSON file to open.
        encoding: Text encoding used when opening the file (default "utf-8").

    Returns:
        The Python object decoded by ``json.load``.
    """
    with open(fname, encoding=encoding) as json_file:
        return json.load(json_file)

def jsondump(j, fname):
    """Write ``j`` to ``fname`` as one JSON document.

    The file is opened in write mode (replacing any existing content) with
    UTF-8 encoding, and non-ASCII characters are written as-is rather than
    as ``\\uXXXX`` escapes.

    Args:
        j: JSON-serializable object to store.
        fname: Destination file path.
    """
    with open(fname, "w", encoding="UTF8") as out_file:
        json.dump(obj=j, fp=out_file, ensure_ascii=False)

def jsonlload(fname, encoding="utf-8"):
    """Read a JSON Lines file into a list, one decoded object per line.

    Whitespace-only lines (for example an extra trailing newline at the end
    of the file) are skipped instead of being passed to ``json.loads``,
    which would raise ``json.JSONDecodeError`` on them.

    Args:
        fname: Path of the ``.jsonl`` file to open.
        encoding: Text encoding used when opening the file (default "utf-8").

    Returns:
        A list with the decoded object of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(fname, encoding=encoding) as f:
        # Iterate the file object directly so lines are decoded as they are
        # read, rather than first materializing every line with readlines().
        return [json.loads(line) for line in f if line.strip()]

In [3]:
# Report the PyTorch build in use and what CUDA hardware it can see.
# NGPU holds the number of CUDA devices reported by torch.
cuda_available = torch.cuda.is_available()
NGPU = torch.cuda.device_count()

print(f'torch.__version__: {torch.__version__}')
print(f'torch.cuda.is_available(): {cuda_available}')
print(f'NGPU: {NGPU}')

torch.__version__: 1.7.1
torch.cuda.is_available(): True
NGPU: 4


# Paths and Modes

In [4]:
# Switch between the two ways this notebook is run:
#   True  -> the dev file (EVAL_DATA_PATH) is loaded and predictions are scored
#            with evaluation_f1 in the Evaluation section.
#   False -> the test file (TEST_DATA_PATH) is loaded and predictions are
#            written to RESULT_SAVE_NAME.
EVAL_MODE = False

# Name shared by the result file and both checkpoint directories.
_RUN_NAME = 'klue_roberta_base_mlm_fine_tuned_uncleaned_v13'
_RUN_DIR = f'training_results/{_RUN_NAME}'

# File the predictions are saved to when EVAL_MODE is False.
RESULT_SAVE_NAME = f'{_RUN_NAME}.json'

# Checkpoints for the two classifiers loaded in "Load Model and Tokenizer".
ACD_CHECKPOINT = f'{_RUN_DIR}/acd/{_RUN_NAME}/checkpoint-23440'
ASC_CHECKPOINT = f'{_RUN_DIR}/asc/{_RUN_NAME}/checkpoint-1000'

# JSON Lines inputs: test split (inference) and dev split (evaluation).
TEST_DATA_PATH, EVAL_DATA_PATH = (
    'dataset/nikluge-sa-2022-test.jsonl',
    'dataset/nikluge-sa-2022-dev.jsonl',
)

In [5]:
# In evaluation mode the dev split takes the place of the test split.
if EVAL_MODE == True:
    TEST_DATA_PATH = EVAL_DATA_PATH
print('>>>>> >>>>> >>>>> ', TEST_DATA_PATH, ' <<<<< <<<<< <<<<<', '\n', sep='')

test_data = jsonlload(TEST_DATA_PATH)

if EVAL_MODE == True:
    # Remove the element at index 1 of every gold annotation, leaving one field
    # fewer per annotation (presumably the target-span field is the one dropped
    # -- confirm against the dataset schema). The prediction step appends
    # two-element [pair, polarity] annotations.
    for row in test_data:
        for annotation in row['annotation']:
            del annotation[1]

    # Keep the gold labels aside for scoring, then blank the annotations of the
    # copy that is fed to the models.
    true_data = copy.deepcopy(test_data)
    for row in test_data:
        row['annotation'] = []

    # Preview the first five gold rows.
    for row in true_data[:5]:
        print(row)
    print()

# Preview the first five rows that will be sent to inference.
for row in test_data[:5]:
    print(row)

>>>>> >>>>> >>>>> dataset/nikluge-sa-2022-test.jsonl <<<<< <<<<< <<<<<

{'id': 'nikluge-sa-2022-test-00001', 'sentence_form': '하나 사려고 알아보는 중인데 맘에드는거 발견', 'annotation': []}
{'id': 'nikluge-sa-2022-test-00002', 'sentence_form': '동양인 피부톤과 잘 어울리고 우아한 분위기를 풍긴다네?', 'annotation': []}
{'id': 'nikluge-sa-2022-test-00003', 'sentence_form': '근데 이건 마르살라보다 더 지나친 색 같은데..', 'annotation': []}
{'id': 'nikluge-sa-2022-test-00004', 'sentence_form': '나스 색조가 다 그렇지만서도 어데이셔스 라인은 진짜 색 기막히게 뽑는것 같다', 'annotation': []}
{'id': 'nikluge-sa-2022-test-00005', 'sentence_form': '색상만 보면 이걸 어떻게 발라.. 싶겠지만 의외로 너무너무 괜찮다', 'annotation': []}


# Configs

In [6]:
# Every "entity#property" category label the category-detection model is
# queried with, one query per (sentence, label) pair.
entity_property_pair = [
    # 본품
    '본품#가격', '본품#다양성', '본품#디자인', '본품#인지도', '본품#일반', '본품#편의성', '본품#품질',
    # 브랜드
    '브랜드#가격', '브랜드#디자인', '브랜드#인지도', '브랜드#일반', '브랜드#품질',
    # 제품 전체
    '제품 전체#가격', '제품 전체#다양성', '제품 전체#디자인', '제품 전체#인지도',
    '제품 전체#일반', '제품 전체#편의성', '제품 전체#품질',
    # 패키지/구성품
    '패키지/구성품#가격', '패키지/구성품#다양성', '패키지/구성품#디자인',
    '패키지/구성품#일반', '패키지/구성품#편의성', '패키지/구성품#품질',
]

# Class index -> label name for the binary "does this category apply" classifier,
# plus the inverse mapping.
tf_id_to_name = ['True', 'False']
tf_name_to_id = {name: idx for idx, name in enumerate(tf_id_to_name)}

# Class index -> label name for the polarity classifier, plus the inverse mapping.
polarity_id_to_name = ['positive', 'negative', 'neutral']
polarity_name_to_id = {name: idx for idx, name in enumerate(polarity_id_to_name)}

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

len(entity_property_pair)

25

# Load Model and Tokenizer

In [7]:
def _load_classifier(checkpoint):
    """Load the sequence-classification model and its tokenizer from `checkpoint`."""
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# ACD: classifier whose output is mapped to 'True'/'False' for a (sentence, category) pair.
acd_model, acd_tokenizer = _load_classifier(ACD_CHECKPOINT)

# ASC: classifier whose output is mapped to a polarity name for a (sentence, category) pair.
asc_model, asc_tokenizer = _load_classifier(ASC_CHECKPOINT)

# Inference

In [8]:
def _classify_pair(tokenizer, model, form, pair, id_to_name):
    """Classify one (sentence, category) text pair and return the predicted label name.

    The two texts are tokenized together as a sentence pair, moved to the
    module-level `device`, run through `model` without gradient tracking, and
    the arg-max class index is looked up in `id_to_name`.
    """
    encoded = tokenizer(form, pair, truncation=True, return_tensors="pt")
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded)['logits']

    # A single pair is encoded per call, so only the first row is read;
    # .item() yields a plain Python int for the list lookup.
    return id_to_name[logits.argmax(-1)[0].item()]


def predict_from_korean_form(acd_tokenizer, asc_tokenizer, acd_model, asc_model, data):
    """Annotate every sentence in `data` with predicted (category, polarity) pairs.

    For each sentence and each label in `entity_property_pair`, the ACD model
    decides whether the category applies ('True'/'False'). For every category
    predicted 'True', the ASC model predicts a polarity name and
    `[category, polarity]` is appended to the sentence's 'annotation' list.

    Args:
        acd_tokenizer: Tokenizer used for the category-detection model.
        asc_tokenizer: Tokenizer used for the polarity model.
        acd_model: Classifier whose class indices map to `tf_id_to_name`.
        asc_model: Classifier whose class indices map to `polarity_id_to_name`.
        data: Iterable of dicts with a 'sentence_form' entry. Each dict's
            'annotation' entry is reset and filled IN PLACE, so pass a copy
            if the input must stay untouched.

    Returns:
        The same `data` object, with 'annotation' populated.

    Notes:
        Both models are moved to the module-level `device` and put in eval mode.
        Rows whose 'sentence_form' is not a string are reported and left with an
        empty annotation list. Messages are emitted with `tqdm.write` so they
        do not break the progress bar.
    """
    acd_model.to(device)
    acd_model.eval()
    asc_model.to(device)
    asc_model.eval()

    for sentence in tqdm(data):
        form = sentence['sentence_form']
        sentence['annotation'] = []
        if not isinstance(form, str):
            tqdm.write(f"form type is wrong:  {form}")
            continue

        for pair in entity_property_pair:
            if _classify_pair(acd_tokenizer, acd_model, form, pair, tf_id_to_name) != 'True':
                continue

            polarity = _classify_pair(asc_tokenizer, asc_model, form, pair, polarity_id_to_name)

            # The emitted label is kept separate from the loop variable so the
            # correction below never alters what the models were queried with.
            label = pair
            if pair == '패키지/구성품#가격':
                # NOTE(review): this category is written out with a space after
                # the slash -- presumably the spelling the reference labels /
                # scorer expect; confirm against the dataset.
                tqdm.write(f'{pair} found.')
                label = '패키지/ 구성품#가격'
                tqdm.write(f'corrected as {label}')

            sentence['annotation'].append([label, polarity])

    return data

In [9]:
# Predict on a deep copy: predict_from_korean_form rewrites each row's
# 'annotation' in place, and `test_data` should stay as loaded.
pred_data = predict_from_korean_form(
    acd_tokenizer,
    asc_tokenizer,
    acd_model,
    asc_model,
    copy.deepcopy(test_data),
)

if EVAL_MODE == False:
    # Outside evaluation mode the predictions are written to disk and then
    # read back from that same file.
    save_path, file_name = './', RESULT_SAVE_NAME
    result_path = os.path.join(save_path, file_name)

    jsondump(pred_data, result_path)
    pred_data = jsonload(result_path)

# Row counts of the input and of the predictions, shown as the cell output.
len(test_data), len(pred_data)

 79%|███████▉  | 1686/2127 [11:54<03:09,  2.33it/s]

패키지/구성품#가격 found.
corrected as 패키지/ 구성품#가격


 81%|████████  | 1722/2127 [12:09<03:00,  2.25it/s]

패키지/구성품#가격 found.
corrected as 패키지/ 구성품#가격


100%|██████████| 2127/2127 [15:03<00:00,  2.35it/s]


(2127, 2127)

# Evaluation

In [10]:
# Scoring only happens in evaluation mode, where `true_data` (the gold
# annotations) was set aside while loading the data.
if EVAL_MODE == True:
    run_summary = (
        ('ACD_CHECKPOINT: ', ACD_CHECKPOINT),
        ('ASC_CHECKPOINT: ', ASC_CHECKPOINT),
        ('INFERENCE DATA: ', TEST_DATA_PATH),
        ('EVAL_MODE :', EVAL_MODE),
    )
    for label, value in run_summary:
        print(label, value)

    # evaluation_f1 comes from module.score; only the first two entries of the
    # mapping it returns are shown.
    result = evaluation_f1(true_data, pred_data)
    scores = list(result.items())
    print(scores[0])
    print(scores[1])