In [None]:
#======================================================================================================
# sentence-bert nli로 훈련된 모델을 다시 sts(semantic textual similarity) 파일로 훈련시킴
# => sentence-transformers 패키지를 이용하여 구현함.(*pip install -U sentence-transformers 설치 필요)
#
# 도큐먼트 : https://www.sbert.net/index.html
# 소스참고 : https://github.com/BM-K/KoSentenceBERT-ETRI
#  => KoSentenceBERT-ETRI-master\KoSentenceBERT-ETRI-master\con_training_sts.py

# pip install -U sentence-transformers
#======================================================================================================

from torch.utils.data import DataLoader
import math
from sentence_transformers import models, losses
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datetime import datetime
import sys
import os
import gzip
import csv
sys.path.append('..')
from myutils import seed_everything, GPU_info, mlogging

# Notebook-wide runtime setup via the project helpers in myutils:
#  - logger: named "s-bert" (logfilename="s-bert" presumably selects a log file; see myutils)
#  - device: whatever GPU_info() reports/returns
#  - a fixed random seed so re-runs are reproducible
SEED = 111
logger = mlogging(loggername="s-bert", logfilename="s-bert")
device = GPU_info()
seed_everything(SEED)

In [None]:
import os

# --- Model paths -----------------------------------------------------------
# Existing (NLI-trained) Sentence-BERT model that is fine-tuned on STS below.
smodel_path = "../../data11/model/sbert/sbert-mydistilbertv1.1-sts-distil"
# Where the STS fine-tuned Sentence-BERT model is saved.
smodel_save_path = '../../data11/model/sbert/sbert-mydistilbertv1.1-sts-distil-sts'

# Evaluation results (cosine-similarity correlations, etc.) are written here,
# i.e. next to the saved model. NOTE(review): the original note mentions
# similarity_evaluation_xxxx.xls; sentence-transformers evaluators typically
# write CSV files -- confirm for the installed version.
output_path = smodel_save_path
os.makedirs(output_path, exist_ok=True)

# --- Training hyper-parameters ----------------------------------------------
train_batch_size, num_epochs = 32, 100

# --- KorSTS files used for training, evaluation (dev) and testing -------------
train_file, eval_file, test_file = (
    '../../data11/korpora/korsts/tune_train.tsv',
    '../../data11/korpora/korsts/tune_dev.tsv',
    '../../data11/korpora/korsts/tune_test.tsv',
)

In [None]:
# Load the existing Sentence-BERT model from smodel_path and show its module structure.
model = SentenceTransformer(model_name_or_path=smodel_path)
print(model)

In [None]:
# Load the STS dev / test / train splits, then build the training DataLoader and loss.


def _read_sts_samples(path, max_score=5.0):
    """Read an STS TSV file into a list of ``InputExample`` objects.

    Every non-blank line must contain exactly three tab-separated fields:
    ``sentence1<TAB>sentence2<TAB>score``. No header row is expected; every
    non-blank line is parsed as data. The gold score is divided by
    ``max_score`` (default 5.0, which presumes a 0-5 score scale) to produce
    the example label.

    Args:
        path: Path of the UTF-8 encoded TSV file.
        max_score: Value the raw score is divided by for normalization.

    Returns:
        list of InputExample, one per non-blank line, in file order.

    Raises:
        ValueError: if a line does not have exactly three fields or its score
            is not a number. The message names the file and the line number.
    """
    samples = []
    with open(path, 'rt', encoding='utf-8') as f_in:
        for line_no, line in enumerate(f_in, start=1):
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing on tuple unpacking.
            if not line.strip():
                continue
            fields = line.split('\t')
            if len(fields) != 3:
                raise ValueError("{}:{}: expected 3 tab-separated fields, got {}".format(
                    path, line_no, len(fields)))
            s1, s2, score = fields
            # The score is the last field, so it carries the line terminator.
            score = score.strip()
            try:
                label = float(score) / max_score
            except ValueError as exc:
                raise ValueError("{}:{}: invalid score {!r}".format(
                    path, line_no, score)) from exc
            samples.append(InputExample(texts=[s1, s2], label=label))
    return samples


logger.info("Read STS train dataset")

dev_samples = _read_sts_samples(eval_file)
test_samples = _read_sts_samples(test_file)
train_samples = _read_sts_samples(train_file)

train_dataset = SentencesDataset(train_samples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)

In [None]:
# Development set: measure the correlation between cosine scores and gold labels.
logger.info("Read dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Warm up over the first 10% of all training steps (not 10% of the data).
# len(train_dataloader) is the number of batches per epoch. drop_last is not
# set on the DataLoader, so this count includes the final partial batch, which
# len(train_dataset) / train_batch_size would under-count.
warmup_ratio = 0.1
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * warmup_ratio)
logger.info("Warmup-steps: {}".format(warmup_steps))


In [None]:
# Fine-tune the model on STS with the cosine-similarity objective built above.
# The dev evaluator runs every 1000 steps. output_path is where fit() saves the
# model. NOTE(review): with an evaluator, sentence-transformers presumably saves
# only when the dev score improves -- confirm for the installed version.
training_objectives = [(train_dataloader, train_loss)]
model.fit(
    train_objectives=training_objectives,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=1000,
    output_path=smodel_save_path,
)

In [None]:
# -----------------------------------------------------------------------------
# Reload the model saved at smodel_save_path and measure its performance on the
# held-out STS test split (test_samples, read from test_file).
# The evaluator's results are written under output_path.
# -----------------------------------------------------------------------------

model = SentenceTransformer(smodel_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    test_samples,
    name='sts-test',
)
test_evaluator(model, output_path=output_path)