
## 1. Load trained model
---

In [None]:
from sagemaker.estimator import Estimator

# Model variant label; the (commented-out) Hub upload section uses it to build the repo name.
KIND = "unsupervised"

# Name of the SageMaker training job to attach to.
train_job_name = "kosimcse-roberta-base-unsupervised-2023-2023-10-05-08-45-35-241"

# Re-attach to that job to obtain an Estimator handle; its `model_data` is read
# by the download cell that follows.
estimator = Estimator.attach(training_job_name=train_job_name)

In [None]:
import json, os

local_model_dir = 'model_from_sagemaker'

!rm -rf {local_model_dir}
if not os.path.exists(local_model_dir):
    os.makedirs(local_model_dir)

!aws s3 cp {estimator.model_data} {local_model_dir}/model.tar.gz
!tar -xzf {local_model_dir}/model.tar.gz -C {local_model_dir}
!rm -rf {local_model_dir}/model.tar.gz

In [None]:
import glob
import torch
import transformers
import numpy as np
from collections import OrderedDict
from transformers import BertForSequenceClassification, AutoTokenizer
import json, os
from serving_src.model import SimCSEConfig, SimCSEModel

# Directory holding the unpacked SageMaker artifacts (same value as in the download cell).
local_model_dir = 'model_from_sagemaker'

# Prefer the first GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# config.json supplies both the base checkpoint name and the keyword
# arguments passed to SimCSEConfig.
with open(f'{local_model_dir}/config.json') as config_file:
    json_object = json.load(config_file)
base_model = json_object["base_model"]

# The tokenizer is taken from the base checkpoint; the SimCSE model is
# instantiated from the parsed config.
tokenizer = AutoTokenizer.from_pretrained(base_model)
config = SimCSEConfig(**json_object)
model = SimCSEModel(config)

### Load state dict
DDP(Distributed Data Parallel)로 분산 훈련을 수행했다면, 체크포인트 state dict의 각 key 앞에 붙은 `module.` 접두어를 제거해야 합니다.

In [None]:
# Locate the checkpoint among the unpacked artifacts. Sorting makes the choice
# deterministic if several *.pt files exist, and an empty match produces a clear
# error instead of a bare IndexError.
checkpoint_paths = sorted(glob.glob(f'{local_model_dir}/*.pt'))
if not checkpoint_paths:
    raise FileNotFoundError(f'No *.pt checkpoint found in {local_model_dir!r}')
model_filename = checkpoint_paths[0]

# map_location='cpu' lets a checkpoint that was saved from GPU training be
# loaded on a CPU-only host.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(model_filename, map_location='cpu')['model']

# A DDP-wrapped model prefixes every parameter name with "module.".
# Strip only that leading prefix: str.replace would also rewrite a "module."
# that appears inside a key (e.g. "encoder.submodule.weight").
ddp_prefix = 'module.'
new_state_dict = {
    (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
    for key, value in state_dict.items()
}

model.load_state_dict(new_state_dict)
# Switch to evaluation mode for inference.
model = model.eval()

### Inference test

In [None]:
from src.infer import show_embedding_score
# Smoke-test inputs: two differently worded questions about Sunday opening and
# one question about Saturday closing time, all for the same store.
sentences = [
    '이번 주 일요일에 분당 이마트 점은 문을 여나요?',
    '일요일에 분당 이마트는 문 열어요?',
    '분당 이마트 점은 토요일에 몇 시까지 하나요',
]

# Run the check with the model moved to the CPU.
show_embedding_score(tokenizer, model.cpu(), sentences)

<br>

## 2. (Optional) Push model to Hugging Face Hub
---

In [None]:
# from huggingface_hub import login
# login()

In [None]:
# hf_hub_path = f"KoSimCSE-{KIND}-{base_model.split('/')[-1]}"
# print(hf_hub_path)

In [None]:
# model.push_to_hub(
#     repo_id=hf_hub_path,
#     safe_serialization=True
# )
# tokenizer.push_to_hub(hf_hub_path, legacy_format=False)