In [2]:
# Inference settings.
# Builds the deploy arguments from the pretrained model name, the directory
# that holds the fine-tuned (downstream) model, and the maximum sequence length.
# Later cells read args.pretrained_model_name, args.max_seq_length and
# args.downstream_model_checkpoint_fpath from this object.
from ratsnlp.nlpbook.classification import ClassificationDeployArguments
args = ClassificationDeployArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_model_dir="nlpbook/doccls",
    max_seq_length=128,
)

downstream_model_checkpoint_fpath: nlpbook/doccls\epoch=1-val_loss=0.27.ckpt


In [3]:
# Load the tokenizer published under the pretrained model name.
# do_lower_case=False keeps the input text's casing as-is during tokenization.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)

In [4]:
# Load the fine-tuned checkpoint, mapping every stored tensor onto the CPU so
# the notebook also works on machines without a GPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
import torch

fine_tuned_model_ckpt = torch.load(args.downstream_model_checkpoint_fpath, map_location="cpu")

In [5]:
# Load the BERT configuration of the pretrained model.
# The number of labels is not hard-coded: it is taken from the element count of
# the classifier bias stored in the fine-tuned checkpoint.
from transformers import BertConfig

pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt["state_dict"]["model.classifier.bias"].numel(),
)

In [6]:
# Initialize the BERT sequence-classification model.
# The model is built from the configuration object only; the fine-tuned
# checkpoint weights are injected by a later cell via load_state_dict.
from transformers import BertForSequenceClassification
model = BertForSequenceClassification(pretrained_model_config)

In [9]:
# Inject the fine-tuned checkpoint weights into the model.
# Checkpoint keys (e.g. "model.classifier.bias") carry a leading "model."
# that is removed before loading into the bare model. Only the leading prefix
# is stripped: str.replace would also delete any "model." occurring later in
# a key and silently rename that parameter.
CKPT_KEY_PREFIX = "model."
model.load_state_dict({
    (k[len(CKPT_KEY_PREFIX):] if k.startswith(CKPT_KEY_PREFIX) else k): v
    for k, v in fine_tuned_model_ckpt["state_dict"].items()
})

<All keys matched successfully>

In [11]:
# Switch to evaluation mode.
# This disables techniques that are only used during training, such as dropout.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(300, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [13]:
# Inference function
def inference_fn(sentence):
    """Classify one sentence as positive or negative with the fine-tuned model.

    Args:
        sentence: Raw input text to classify.

    Returns:
        dict with the keys:
            sentence: the input text, unchanged.
            prediction: "Positive" if class index 1 has the highest
                probability, otherwise "Negative".
            positive_data / negative_data: label plus the class probability
                rounded to 4 decimal places.
            positive_width / negative_width: the same probability expressed
                as a percentage string (e.g. "98.62%").
    """
    # Tokenize as a batch of one, padded/truncated to the configured length.
    inputs = tokenizer(
        [sentence],
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
    )

    with torch.no_grad():
        # Convert the tokenizer output to tensors, then run the model.
        outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})
        # Softmax over the class dimension turns logits into probabilities.
        prob = outputs.logits.softmax(dim=1)
        positive_prob = round(prob[0][1].item(), 4)
        negative_prob = round(prob[0][0].item(), 4)
        # Take the argmax along the class dimension of the single batch row
        # instead of relying on the flattened index of the whole tensor.
        predicted_class = torch.argmax(prob, dim=1)[0].item()
        pred = "Positive" if predicted_class == 1 else "Negative"

    return {
        'sentence': sentence,
        'prediction': pred,
        'positive_data': f"Positive: {positive_prob}",
        'negative_data': f"Negative: {negative_prob}",
        # Re-round after scaling: 0.9862 * 100 is 98.61999999999999 in binary
        # floating point, which would otherwise leak into the width string.
        'positive_width': f"{round(positive_prob * 100, 2)}%",
        'negative_width': f"{round(negative_prob * 100, 2)}%",
    }

In [14]:
# Start the web service.
# get_web_service_app builds an app around inference_fn, and app.run() starts
# serving it. Per the log output of this cell it is a Flask app listening on
# http://127.0.0.1:5000, and the cell keeps running until interrupted
# ("Press CTRL+C to quit").
from ratsnlp.nlpbook.classification import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()

 * Serving Flask app 'ratsnlp.nlpbook.classification.deploy'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


 * Running on http://1b4b-211-204-110-53.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [01/Feb/2023 15:31:42] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:31:43] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [01/Feb/2023 15:32:05] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:32:05] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [01/Feb/2023 15:32:17] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:32:17] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [01/Feb/2023 15:32:30] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:32:56] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:33:02] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:35:50] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:36:28] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:36:37] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [01/Feb/2023 15:36:44] "POST /api HTTP/1.1" 200 -
