In [None]:
#========================================================================================================================
# Example of NLI training for sentence-bert (sbert) with a CrossEncoder
# => a cross-encoder takes 2 sentences (sentence1, sentence2) together as one input and outputs a score for the pair;
#    with num_labels=3 (this NLI setup) it outputs one score per label (contradiction / entailment / neutral)
#    rather than a single 0~1 similarity value.
#
# => References : https://www.sbert.net/examples/training/cross-encoder/README.html
#         https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/cross-encoder/training_stsbenchmark.py  
#========================================================================================================================
import torch 
import os
import time
import numpy as np
from os import sys
from datetime import datetime
sys.path.append('../../')
from myutils import seed_everything, GPU_info, mlogging
from torch.utils.data import DataLoader
import math

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CESoftmaxAccuracyEvaluator
from sentence_transformers import InputExample

# Project helper GPU_info() presumably reports GPU info and returns the device to use — TODO confirm;
# `device` is not referenced again in the visible cells.
device = GPU_info()

# Logger shared by the data-loading, training and evaluation cells below.
# NOTE(review): the log file name says "crossencocer" / "sts" although this notebook trains on NLI;
# it is kept unchanged so existing log paths stay the same.
logger = mlogging(
    loggername="sbertcross",
    logfilename="../../../log/sbert-crossencocer-train-sts",
)

In [None]:
# --- data: KorNLI (SNLI-ko) train file and XNLI Korean dev file ---
train_file = '../../../data11/korpora/kornli/snli_1.0_train.ko.tsv'
eval_file = '../../../data11/korpora/kornli/xnli.dev.ko-1.tsv'

# --- optimisation hyper-parameters ---
num_epochs = 3
train_batch_size = 64
lr = 3e-5    # default=2e-5
eps = 1e-8   # optimizer epsilon (passed via optimizer_params in the training cell); in Adam-style
             # optimizers it is added to the update denominator to avoid division by zero

# --- reproducibility ---
seed = 111
seed_everything(seed)

In [None]:
# Load the training / dev data
# => every row becomes InputExample(texts=[sentence1, sentence2], label=<int label>)
logger.info("Read kornli train/dev dataset")

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}


def read_nli_tsv(path, label_map):
    """Parse a tab-separated NLI file into a list of ``InputExample``.

    Every data row must be ``sentence1 <TAB> sentence2 <TAB> label``; the
    whitespace-stripped label must be a key of ``label_map`` and is converted
    to the mapped int. Blank lines are skipped. A first row whose label is not
    a known label (e.g. a ``sentence1/sentence2/gold_label`` column header) is
    treated as a header and skipped.

    Args:
        path: path of the UTF-8 encoded .tsv file.
        label_map: mapping from label string to int class id.

    Returns:
        list of InputExample, in file order.

    Raises:
        ValueError: a data row does not have exactly 3 fields or has an
            unknown label; the message contains the file path and line number.
    """
    samples = []
    first_row = True
    with open(path, "rt", encoding="utf-8") as fIn:
        # Iterate lazily instead of readlines() so the file is not first
        # materialised as a full list of lines.
        for lineno, line in enumerate(fIn, start=1):
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # blank line: nothing to parse
            fields = line.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 3 tab-separated fields, got {len(fields)}")
            s1, s2, label = fields
            label = label.strip()
            is_header = first_row and label not in label_map
            first_row = False
            if is_header:
                continue  # column header row, not data
            if label not in label_map:
                raise ValueError(f"{path}:{lineno}: unknown label {label!r}")
            samples.append(InputExample(texts=[s1, s2], label=label_map[label]))
    return samples


train_samples = read_nli_tsv(train_file, label2int)

# Load the evaluation (dev) data
dev_samples = read_nli_tsv(eval_file, label2int)

# Fail fast with a clear message instead of training / evaluating on nothing
if not train_samples or not dev_samples:
    raise ValueError(f"empty dataset: train={len(train_samples)}, dev={len(dev_samples)}")

print(f'*train_len:{len(train_samples)}')
print(train_samples[0:3])
print(f'*dev_len:{len(dev_samples)}')
print(dev_samples[0:3])

In [None]:
# Wrap train_samples (a list of InputExample) in a PyTorch DataLoader;
# shuffle=True reshuffles the samples every epoch.
train_dataloader = DataLoader(
    train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)

# Dev-set accuracy is tracked during training with CESoftmaxAccuracyEvaluator.
evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(
    dev_samples,
    name='xnli.dev.ko.tsv',
)

In [None]:
# Load the pre-trained base model as a CrossEncoder with one output per NLI label (len(label2int) == 3).
model_path = "bongsoo/albert-small-kor-v1"

# Fixed output directory; append datetime.now().strftime("%Y-%m-%d_%H-%M") for a per-run folder.
model_save_path = '../../../data11/model/moco/cross/albert-small-kor-cross-nli'

model = CrossEncoder(
    model_path,
    num_labels=len(label2int),
)

In [None]:
# Start training
# => model_save_path receives the model and the evaluator's CESoftmaxAccuracyEvaluator...results.csv file
# Warm the learning rate up over the first 10% of all optimisation steps (batches per epoch * epochs).
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
# Evaluate on the dev set every 20% of all optimisation steps (twice the warm-up length).
evaluation_steps = warmup_steps * 2

logger.info(f"Warmup-steps: {warmup_steps}")

# Train the model
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    output_path=model_save_path,
    save_best_model=True,  # default = True: keep the checkpoint with the best dev score in output_path
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluation_steps=evaluation_steps,
    # NOTE(review): 'correct_bias' is accepted by transformers.AdamW; if the installed
    # sentence-transformers defaults to torch.optim.AdamW this key is rejected — confirm version.
    optimizer_params={'lr': lr, 'eps': eps, 'correct_bias': False},
)

In [None]:
# Load the stored model and evaluate its accuracy on the XNLI (Korean) test set
# => reload the trained cross-encoder saved in model_save_path and measure its softmax accuracy
##############################################################################
import time

# Uncomment to evaluate the base model (model_path in the model-loading cell) instead of the fine-tuned one
#model_save_path = "bongsoo/albert-small-kor-v1"
model_save_path = "../../../data11/model/moco/cross/albert-small-kor-cross-nli"

test_file = '../../../data11/korpora/kornli/xnli.test.ko-1.tsv'

# Load the test data. label2int is redefined so this cell can run after only the setup cell.
# Row rules: blank lines are skipped, a first row with an unknown label (column header) is skipped,
# any other row must be "sentence1<TAB>sentence2<TAB>label" with a known label or a ValueError is raised.
label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
test_samples = []
first_row = True
with open(test_file, 'rt', encoding='utf-8') as fIn:
    for lineno, line in enumerate(fIn, start=1):
        line = line.rstrip('\r\n')
        if not line.strip():
            continue
        fields = line.split('\t')
        if len(fields) != 3:
            raise ValueError(
                f"{test_file}:{lineno}: expected 3 tab-separated fields, got {len(fields)}")
        s1, s2, label = fields
        label = label.strip()
        is_header = first_row and label not in label2int
        first_row = False
        if is_header:
            continue
        if label not in label2int:
            raise ValueError(f"{test_file}:{lineno}: unknown label {label!r}")
        test_samples.append(InputExample(texts=[s1, s2], label=label2int[label]))

if not test_samples:
    raise ValueError(f"no test samples read from {test_file}")

# The reported processing time covers model loading plus evaluation.
start = time.time()
model = CrossEncoder(model_save_path, num_labels=len(label2int))

evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(test_samples, name='xnli.test.ko.tsv')
result = evaluator(model)

logger.info("\n")
logger.info(f"model path: {model_save_path}")
logger.info(f'=== result: {result} ===')
logger.info(f'=== 처리시간: {time.time() - start:.3f} 초 ===')
logger.info("==============================================")
logger.info("\n")