In [1]:
# Inference settings
from ratsnlp.nlpbook.generation import GenerationDeployArguments

# Name of the base pretrained checkpoint plus the directory that holds the
# fine-tuned checkpoint; later cells read `args.pretrained_model_name` and
# `args.downstream_model_checkpoint_fpath` from this object.
args = GenerationDeployArguments(
    downstream_model_dir="nlpbook/generation",
    pretrained_model_name="skt/kogpt2-base-v2",
)

downstream_model_checkpoint_fpath: nlpbook/generation\epoch=1-val_loss=2.29.ckpt


In [2]:
# Initialize the tokenizer
from transformers import PreTrainedTokenizerFast

# Load the tokenizer of the pretrained checkpoint named in `args`, passing
# "</s>" as the end-of-sequence token.
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_model_name, eos_token="</s>")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [3]:
# Load the model
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Build an uninitialized GPT-2 LM head model from the pretrained configuration;
# the weights themselves come from the fine-tuned checkpoint below.
pretrained_model_config = GPT2Config.from_pretrained(
    args.pretrained_model_name,
)
model = GPT2LMHeadModel(pretrained_model_config)

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source. Tensors are mapped onto the CPU.
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)

# Checkpoint keys carry a leading "model." (presumably added by the training
# wrapper -- confirm). Strip only that leading prefix: str.replace would also
# remove any later "model." substring inside a key.
model.load_state_dict({
    (k[len("model."):] if k.startswith("model.") else k): v
    for k, v in fine_tuned_model_ckpt["state_dict"].items()
})

# Switch to evaluation mode; as the last expression of the cell, the returned
# model is also displayed.
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(51200, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


In [4]:
# Inference
def inference_fn(
        prompt,
        min_length=10,
        max_length=20,
        top_p=1.0,
        top_k=50,
        repetition_penalty=1.0,
        no_repeat_ngram_size=0,
        temperature=1.0,
):
    """Generate text from ``prompt`` by sampling with the module-level ``model``.

    The prompt is encoded with the module-level ``tokenizer``, passed to
    ``model.generate`` with ``do_sample=True`` under ``torch.no_grad()``, and
    the first returned sequence is decoded back to text.

    Every numeric argument is coerced with ``int()`` or ``float()`` before it
    is handed to ``model.generate``, so any value those built-ins can convert
    (for example numeric strings) is accepted.

    Args:
        prompt: Text to encode and feed to the model.
        min_length: Forwarded to ``model.generate`` as ``int(min_length)``.
        max_length: Forwarded to ``model.generate`` as ``int(max_length)``.
        top_p: Forwarded as ``float(top_p)``.
        top_k: Forwarded as ``int(top_k)``.
        repetition_penalty: Forwarded as ``float(repetition_penalty)``.
        no_repeat_ngram_size: Forwarded as ``int(no_repeat_ngram_size)``.
        temperature: Forwarded as ``float(temperature)``.

    Returns:
        dict: ``{'result': text}``. ``text`` is the decoded sequence, or a
        fixed Korean error message containing ``<br>`` markup when encoding,
        argument coercion, generation, or decoding raises an ``Exception``.

    Raises:
        BaseException: Non-``Exception`` signals such as ``KeyboardInterrupt``
            and ``SystemExit`` are deliberately not swallowed.
    """
    try:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids,
                do_sample=True,
                top_p=float(top_p),
                top_k=int(top_k),
                min_length=int(min_length),
                max_length=int(max_length),
                repetition_penalty=float(repetition_penalty),
                no_repeat_ngram_size=int(no_repeat_ngram_size),
                temperature=float(temperature),
            )
        # Decode the first sequence of the batch as a plain list of ints.
        generated_sentence = tokenizer.decode(generated_ids[0].tolist())
    except Exception:
        # Best-effort fallback: report a fixed message instead of failing the
        # request. `Exception` (not a bare except) keeps Ctrl+C / SystemExit
        # able to stop the process.
        generated_sentence = """처리 중 오류가 발생했습니다. <br>
            변수의 입력 범위를 확인하세요. <br><br> 
            min_length: 1 이상의 정수 <br>
            max_length: 1 이상의 정수 <br>
            top-p: 0 이상 1 이하의 실수 <br>
            top-k: 1 이상의 정수 <br>
            repetition_penalty: 1 이상의 실수 <br>
            no_repeat_ngram_size: 1 이상의 정수 <br>
            temperature: 0 이상의 실수
            """
    return {
        'result': generated_sentence,
    }

In [5]:
# Start the web service
from ratsnlp.nlpbook.generation import get_web_service_app
# Build the web app around `inference_fn` and start serving. The recorded cell
# output shows a Flask server that runs until CTRL+C, so this cell keeps
# executing while the service is up.
app = get_web_service_app(inference_fn)
app.run()

 * Serving Flask app 'ratsnlp.nlpbook.generation.deploy'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


 * Running on http://de1a-211-204-110-53.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [06/Feb/2023 15:44:06] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:09] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:36] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:39] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:42] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:44] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:46] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:50] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:56] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:44:58] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:45:00] "POST /api HTTP/1.1" 200 -
127.0.0.1 - - [06/Feb/2023 15:45:04] "POST /api HTTP/1.1" 200 -
