In [8]:
#========================================================================================================================
# sentence-bert(sbert)에 CrossEncoder 방식 NLI 훈련 예시임
# => cross-encocoder 방식은 2개의 문장(문장1, 문장2)을 입력했을때 output으로 유사도(0~1값)을 출력해줌
#
# => 참고 : https://www.sbert.net/examples/training/cross-encoder/README.html
#         https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/cross-encoder/training_stsbenchmark.py  
#========================================================================================================================
import torch 
import os
import time
import numpy as np
from os import sys
from datetime import datetime
sys.path.append('../../')
from myutils import seed_everything, GPU_info, mlogging
from torch.utils.data import DataLoader
import math

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CESoftmaxAccuracyEvaluator
from sentence_transformers import InputExample

seed_everything(111)
device = GPU_info()
logger =  mlogging(loggername="sbertcross", logfilename="../../../log/sbert-crossencocer-train-sts")

True
device: cuda:0
cuda index: 0
gpu 개수: 1
graphic name: NVIDIA A30
logfilepath:../../../log/sbert-crossencocer-train-sts_2022-04-22.log


In [4]:
train_file = '../../../korpora/kornli/snli_1.0_train.ko.tsv'
eval_file = '../../../korpora/kornli/xnli.dev.ko-1.tsv'

train_batch_size = 32
num_epochs = 1

In [6]:
# 훈련 데이터 불러오기
# => [sentence1, sentence2], labels 식으로 만듬
logger.info("Read kornli train/dev dataset")

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = []

with open(train_file, "rt", encoding="utf-8") as fIn:
    lines = fIn.readlines()
    for line in lines:
        s1, s2, label = line.split('\t')
        label = label2int[label.strip()]
        train_samples.append(InputExample(texts=[s1, s2], label=label))
  
# 평가 데이터 불러오기
dev_samples = []
with open(eval_file, "rt", encoding="utf-8") as fIn:
    lines = fIn.readlines()
    for line in lines:
        s1, s2, label = line.split('\t')
        label = label2int[label.strip()]
        dev_samples.append(InputExample(texts=[s1, s2], label=label))
        
print(train_samples[0:3])
print(dev_samples[0:3])

2022-04-22 14:05:31,836 - sbertcross - INFO - Read AllNLI train dataset


[<sentence_transformers.readers.InputExample.InputExample object at 0x7fad7c165760>, <sentence_transformers.readers.InputExample.InputExample object at 0x7fad7c1656a0>, <sentence_transformers.readers.InputExample.InputExample object at 0x7fad7c1655b0>]
[<sentence_transformers.readers.InputExample.InputExample object at 0x7faac1a56970>, <sentence_transformers.readers.InputExample.InputExample object at 0x7faac1a569d0>, <sentence_transformers.readers.InputExample.InputExample object at 0x7faac1a56a60>]


In [9]:
#We wrap train_samples, which is a list ot InputExample, in a pytorch DataLoader
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)

#During training, we use CESoftmaxAccuracyEvaluator to measure the accuracy on the dev set.
evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(dev_samples, name='xnli.dev.ko.tsv')

In [10]:
# 모델 불러오기 
model_path = "../../../model/classification/bmc-fpt-wiki_20190620_mecab_false_0311-nouns-0327-ft-nli-0328/bertmodel"
model_save_path = 'output/crossencoder-nli-train-'+datetime.now().strftime("%Y-%m-%d_%H-%M")

model = CrossEncoder(model_path, num_labels=len(label2int))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ../../../model/classification/bmc-fpt-wiki_20190620_mecab_false_0311-nouns-0327-ft-nli-0328/bertmodel and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# 훈련시작
# => model_save_path에 모델과, 평가 CESoftmaxAccuracyEvaluator-dev_results.csv 파일 생성됨
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up
logger.info("Warmup-steps: {}".format(warmup_steps))


# Train the model
model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=10000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)

2022-04-22 14:06:22,856 - sbertcross - INFO - Warmup-steps: 1720
2022-04-22 14:06:22,856 - sbertcross - INFO - Warmup-steps: 1720


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/17193 [00:00<?, ?it/s]

In [12]:
# Load the stored model and evaluate its performance on STS benchmark dataset
# => 훈련되어서 저장된 s-bert 모델을 불러와서 성능 평가 해봄
##############################################################################
import time 

#model_save_path = "../../../model/bert/crossencoder-sts-train-2022-04-22_13-20"
#model_save_path = "../../../model/classification/bmc-fpt-wiki_20190620_mecab_false_0311-nouns-0327-ft-nli-0328/bertmodel"

test_file = '../../../korpora/kornli/xnli.test.ko-1.tsv'

# 테스트 데이터 불러옴 
label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
test_samples = []
with open(test_file, 'rt', encoding='utf-8') as fIn:
    lines = fIn.readlines()
    for line in lines:
        s1, s2, label = line.split('\t')
        label = label2int[label.strip()]
        test_samples.append(InputExample(texts=[s1, s2], label=label))

start = time.time()       
model = CrossEncoder(model_save_path)

evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(test_samples, name='xnli.test.ko.tsv')
result = evaluator(model)

logger.info(f"\n")
logger.info(f"model path: {model_save_path}")
logger.info(f'=== result: {result} ===')
logger.info(f'=== 처리시간: {time.time() - start:.3f} 초 ===')
logger.info("==============================================")
logger.info("\n")

2022-04-22 14:44:56,280 - sbertcross - INFO - 

2022-04-22 14:44:56,280 - sbertcross - INFO - 

2022-04-22 14:44:56,282 - sbertcross - INFO - model path: output/crossencoder-nli-train-2022-04-22_14-06
2022-04-22 14:44:56,282 - sbertcross - INFO - model path: output/crossencoder-nli-train-2022-04-22_14-06
2022-04-22 14:44:56,284 - sbertcross - INFO - === result: 0.6578842315369261 ===
2022-04-22 14:44:56,284 - sbertcross - INFO - === result: 0.6578842315369261 ===
2022-04-22 14:44:56,285 - sbertcross - INFO - === 처리시간: 5.691 초 ===
2022-04-22 14:44:56,285 - sbertcross - INFO - === 처리시간: 5.691 초 ===
2022-04-22 14:44:56,287 - sbertcross - INFO - 

2022-04-22 14:44:56,287 - sbertcross - INFO - 

