In [37]:
#====================================================================================================
# SKT Ko-GPT2 text generation example
# => https://github.com/SKT-AI/KoGPT2
#====================================================================================================
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

# Directory holding the fine-tuned KoGPT2 checkpoint (tokenizer + model files).
model_path='../model/gpt-2/kogpt-2-ft-0504/'
# Alternative: load the public base checkpoint from the Hugging Face hub instead.
#model_path='skt/kogpt2-base-v2'
# Use the GPU when one is available; otherwise fall back to the CPU so that the
# later `model.to(device)` call does not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
# bos_token is set to </s> because, per the original author's note, pretrained
# models usually use </s> as both the start token and the end token.
special_tokens = {
    'bos_token': '</s>',
    'eos_token': '</s>',
    'unk_token': '<unk>',
    'pad_token': '<pad>',
    'mask_token': '<mask>',
}
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, **special_tokens)

# Sanity check: display how a sample sentence is split into tokens.
sample_sentence = "<s>안녕하세요. 한국어 GPT-2 입니다."
tokenizer.tokenize(sample_sentence)


['<s>', '▁안녕', '하', '세', '요.', '▁한국어', '▁G', 'P', 'T', '-2', '▁입', '니다.']

In [40]:
# Load the checkpoint from `model_path` and place the network on `device`.
model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
# Echo the module tree as the cell output.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(51200, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


In [34]:
# Single-prompt check: encode a prompt, generate a continuation, decode it.
text = '오늘은 뭘 먹을까요?'
# The prompt tensor must live on the same device as the model; feeding a CPU
# tensor to a model that was moved to the GPU makes generate() fail with a
# device-mismatch error. Using model.device keeps this cell correct whether or
# not the model has been moved.
input_ids = tokenizer.encode(text, return_tensors='pt').to(model.device)
print(input_ids)

# No sampling arguments are passed, so the model's default decoding is used.
# repetition_penalty discourages the model from repeating the same tokens.
gen_ids = model.generate(input_ids,
                         max_length=128,
                         repetition_penalty=2.0,
                         pad_token_id=tokenizer.pad_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         bos_token_id=tokenizer.bos_token_id,
                         use_cache=True)
print(gen_ids.shape)
print(gen_ids[0])

# Decode without skipping special tokens, so a trailing </s> stays visible.
generated = tokenizer.decode(gen_ids[0])
print(generated)

tensor([[10070,  8135,   739,  7570, 17003,  6969,  8084,   406]])
torch.Size([1, 14])
tensor([10070,  8135,   739,  7570, 17003,  6969,  8084,   406, 16518,  9863,
        16285,  9784,  8234,     1])
오늘은 뭘 먹을까요? 아니면 어떤 음식을 먹죠</s>


In [32]:
# Save the model and tokenizer to files
#tokenizer.save_pretrained('kogpt2')
#model.save_pretrained('kogpt2')

In [41]:
# Helper for trying out text generation
def eval_keywords(keywords, num_return_sequences=2, max_length=50, top_k=30, top_p=0.90):
    """Sample continuations for each keyword and print them.

    Each keyword is prefixed with "<s>", encoded with the notebook-level
    `tokenizer`, and passed to the notebook-level `model` for sampling
    (`do_sample=True`). Every returned sequence is decoded with special
    tokens skipped and printed as "<index>: <text>"; a blank separator is
    printed after the samples of each keyword.

    Args:
        keywords: Iterable of prompt strings.
        num_return_sequences: Number of samples generated per keyword.
        max_length: `max_length` forwarded to `model.generate`.
        top_k: `top_k` forwarded to `model.generate`.
        top_p: `top_p` forwarded to `model.generate`.

    Returns:
        None. Results are printed.
    """
    model.eval()

    for keyword in keywords:
        input_seq = "<s>" + keyword
        # Encode straight to a (1, seq_len) tensor and put it on the model's device.
        input_ids = tokenizer.encode(input_seq, return_tensors='pt').to(model.device)
        sample_outputs = model.generate(input_ids,
                                        do_sample=True,
                                        top_k=top_k,
                                        max_length=max_length,
                                        top_p=top_p,
                                        num_return_sequences=num_return_sequences,
                                        # Explicit padding id, matching the other generate() call
                                        # in this notebook; padding is dropped when decoding below.
                                        pad_token_id=tokenizer.pad_token_id)

        for i, sample_output in enumerate(sample_outputs):
            print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))
        # Separator after the last sample, independent of how many were requested.
        print("\n")
                                   

In [42]:
# Run text generation once for each seed keyword.
keywords = [
    "지미 카터",
    "제임스 얼",
    "수학",
]
eval_keywords(keywords)

0: 지미 카터 미국 대통령 당선인의 사생활 보호법안에 대해 반대하는 이유가 뭔지 말해보세요
1: 지미 카터 미국 대통령  이번에 미국 방문이래


0: 제임스 얼리  그거 내가 알기로는 햄버거 세트인 줄 알았어
1: 제임스 얼스터드라고 하네  제발 기억나


0: 수학 공부에만 집중하면 된다고 해서 공부하다가 힘들어
1: 수학 시험에서 B 학점을 받았는데  교수님이 수업 중간에 문제를 잘못 냈나 봐


