In [1]:
# Inference settings: which pretrained model to start from, the directory that
# holds the fine-tuned checkpoint, and the maximum sequence length for inputs.
from ratsnlp.nlpbook.classification import ClassificationDeployArguments

args = ClassificationDeployArguments(
    max_seq_length=64,
    downstream_model_dir="nlpbook/paircls",
    pretrained_model_name="beomi/kcbert-base",
)

downstream_model_checkpoint_fpath: nlpbook/paircls\epoch=1-val_loss=0.82.ckpt


In [2]:
# Load the tokenizer that belongs to the pretrained model; input text is kept
# case-sensitive (no lower-casing).
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(args.pretrained_model_name, do_lower_case=False)

In [3]:
# Load the fine-tuned checkpoint onto the CPU so that no GPU is required.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
import torch

checkpoint_path = args.downstream_model_checkpoint_fpath
fine_tuned_model_ckpt = torch.load(checkpoint_path, map_location="cpu")

In [4]:
# Load the BERT configuration of the pretrained model. The number of labels is
# not hard-coded: it is taken from the size of the classifier bias stored in
# the fine-tuned checkpoint.
from transformers import BertConfig

pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt["state_dict"]["model.classifier.bias"].numel(),
)

In [5]:
# Build the sequence-classification BERT model from the configuration above.
# The fine-tuned weights are not loaded here; that happens in a later cell.
from transformers import BertForSequenceClassification

model = BertForSequenceClassification(config=pretrained_model_config)

In [6]:
# Inject the fine-tuned weights into the model.
# Checkpoint keys carry a leading "model." prefix (e.g. "model.classifier.bias")
# that the bare BertForSequenceClassification does not use. Strip only that
# leading prefix: str.replace would also delete any later "model." substring
# inside a parameter name and silently corrupt the key.
model.load_state_dict({
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in fine_tuned_model_ckpt["state_dict"].items()
})

<All keys matched successfully>

In [7]:
# Switch the model to evaluation mode before running inference.
# The call is the last expression of the cell, so its return value (the model
# itself) is displayed as the cell output.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(300, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [8]:
# Inference function
def inference_fn(premise, hypothesis):
    """Predict the relation between a premise and a hypothesis sentence.

    The sentence pair is tokenized (padded/truncated to
    ``args.max_seq_length``), passed through the fine-tuned model, and the
    logits are turned into probabilities with a softmax.

    Label indices are interpreted as 0 = entailment, 1 = contradiction,
    anything else = neutral. The model is expected to produce at least three
    logits; fewer raises ``IndexError``.

    Args:
        premise: The premise sentence.
        hypothesis: The hypothesis sentence.

    Returns:
        dict with the echoed ``premise`` and ``hypothesis``, the predicted
        label in ``prediction``, display strings ``*_data`` (label plus
        probability rounded to two decimals) and ``*_width`` (the same
        probability as a percentage with one decimal, e.g. ``"57.0%"``).
    """
    inputs = tokenizer(
        [(premise, hypothesis)],
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
    )
    with torch.no_grad():
        # Convert the tokenizer output to PyTorch tensors and run the model.
        outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})
        prob = outputs.logits.softmax(dim=1)  # softmax over the logits
        # Single-item batch, so the flattened argmax is the label index.
        pred_index = torch.argmax(prob).item()
        scores = [round(p, 2) for p in prob[0].tolist()]

    entailment_prob = scores[0]
    contradiction_prob = scores[1]
    neutral_prob = scores[2]

    pred = {
        0: "참 (entailment)",
        1: "거짓 (contradiction)",
    }.get(pred_index, "중립 (neutral)")

    # Format widths with a fixed single decimal: plain `prob * 100` exposes
    # binary floating-point noise (0.57 * 100 -> 56.99999999999999).
    return {
        'premise': premise,
        'hypothesis': hypothesis,
        'prediction': pred,
        'entailment_data': f"참 {entailment_prob}",
        'contradiction_data': f"거짓 {contradiction_prob}",
        'neutral_data': f"중립 {neutral_prob}",
        'entailment_width': f"{entailment_prob * 100:.1f}%",
        'contradiction_width': f"{contradiction_prob * 100:.1f}%",
        'neutral_width': f"{neutral_prob * 100:.1f}%",
    }

In [9]:
# Wrap inference_fn in the web service app provided by ratsnlp and start it.
# Per the recorded output this serves a Flask app on http://127.0.0.1:5000 and
# keeps running until interrupted (CTRL+C), so no later cell runs meanwhile.
from ratsnlp.nlpbook.paircls import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()

 * Serving Flask app 'ratsnlp.nlpbook.paircls.deploy'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


 * Running on http://0eb1-211-204-110-53.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [03/Feb/2023 23:14:55] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:14:57] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:15:07] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:15:26] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:15:39] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:16:02] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [03/Feb/2023 23:16:33] "POST /api HTTP/1.1" 200 -
