In [1]:
#====================================================================================================
# SKT KoGPT2 text-generation example
# => https://github.com/SKT-AI/KoGPT2
#====================================================================================================
import torch
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast

# Local checkpoint directory (the name suggests a KoGPT2 fine-tuned for summarization).
model_path = '../model/gpt-2/kogpt-2-ft-summarizer-0509/'
# Alternative: the public pretrained base model on the Hugging Face hub.
#model_path='skt/kogpt2-base-v2'

# Use the first GPU when one is visible; otherwise fall back to CPU so the notebook
# can still be run end-to-end (Restart & Run All) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Special-token setup follows the KoGPT2 loading recipe: "</s>" is registered as BOS as
# well as EOS. (Original author's note: models trained this way typically use </s> both
# to start and to end a sequence.)
special_tokens = {
    'bos_token': '</s>',
    'eos_token': '</s>',
    'unk_token': '<unk>',
    'pad_token': '<pad>',
    'mask_token': '<mask>',
}
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, **special_tokens)

# Sanity check: show how a sample sentence is split into subword pieces.
tokenizer.tokenize("<s>안녕하세요. 한국어 GPT-2 입니다.")


['<s>', '▁안녕', '하', '세', '요.', '▁한국어', '▁G', 'P', 'T', '-2', '▁입', '니다.']

In [3]:
# Load the checkpoint weights and place them on the configured device. nn.Module.to()
# returns the module itself, so `model` is the same object either way.
model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(51200, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )


In [4]:
# Total parameter count; both flags are spelled out at their default values
# (count non-trainable parameters too, and include the embedding matrices).
model.num_parameters(only_trainable=False, exclude_embeddings=False)

125164032

In [12]:
# Generate a continuation for a one-word prompt ('날씨' = "weather"). No sampling flags are
# passed, so decoding follows the model's default generation config (greedy unless that
# config says otherwise); repetition_penalty discourages repeated tokens.
text = '날씨'
input_ids = tokenizer.encode(text, return_tensors='pt')
print(input_ids)

generation_kwargs = dict(
    max_length=128,
    repetition_penalty=2.0,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    use_cache=True,
)
gen_ids = model.generate(input_ids.to(device), **generation_kwargs)
print(gen_ids.shape)
print(gen_ids[0])

# skip_special_tokens=True drops registered special tokens such as <s> / </s> from the text.
# NOTE(review): "<summarize>" still appears in the decoded output, so it is presumably not
# registered as a special token in this tokenizer — confirm if it should be stripped.
generated = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
print(generated)

tensor([[32016]])
torch.Size([1, 56])
tensor([32016,   739, 11294,  8363, 27073,  7514,  8263,  9355, 40612, 13168,
         8022,  9167, 16210, 28597, 10771,  9447,  9429,  7756, 10199,  9737,
         9327,  8705,  9249, 13312,  9368,  9430,  9623,  7426, 10956,  6958,
        12503,  8102,  8267,  9025,  9341, 11200,  6824, 10089,  7252,  7182,
         9724,   457,   459, 10473,  9837, 21049,   443,   405,  7657, 21334,
        13993,  8367, 31872, 13940,   375,     1], device='cuda:0')
날씨  봄철 미세먼지 때문에 호흡기 질환에 대한 관심이 높아지고 있는데 특히 황사 먼지가 심할 경우 호흡기를 통해 폐로 들어가기 쉬워질 수 있어 주의가 요구된다 <summarize>봄철에 공기청정기 필수



In [11]:
# Save the model and tokenizer to disk (disabled; uncomment to export to ./kogpt2)
#tokenizer.save_pretrained('kogpt2')
#model.save_pretrained('kogpt2')

In [7]:
# text generation 테스트 해보는 함수 
def eval_keywords(keywords):
    model.eval()
    
    for keyword in keywords:
        input_seq = "<s>" + keyword
        generated = torch.tensor(tokenizer.encode(input_seq)).unsqueeze(0)
        generated = generated.to(device)
        sample_outputs = model.generate(generated,
                                        do_sample = True,
                                        top_k=30,
                                        max_length=50,
                                        top_p=0.90,
                                        num_return_sequences=2)
        
        for i, sample_output in enumerate(sample_outputs):
            # skip_special_tokens=True 로 해서 <s>, </s> 토큰들은 출력안함
            print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))
            if i == 1:
                print("\n")
                                   

In [8]:
# 각 단어를 입력하여, text generation 해 봄
keywords = ["지미 카터","제임스 얼","수학"]
eval_keywords(keywords)

0: 지미 카터 전 미국 대통령의 부인  카스트로 의원은 지난 16일 방송된 JTBC  슈퍼맨 김희재  미스터  에서는 자신의 친오빠인 고 김민경을 대신해 딸로서 딸내미를 키워주
1: 지미 카터 전 미국 대통령이 지난달 15일 서울 신라호텔에서 열린 아시아 태평양 경제협력체(APEC) 정상회의에 참석해 한미동맹의 가치를 높이 평가하며 한국의 미래와 자유민주주의 체제의 미래에 대해 이야기하는 등 한반도 비핵화와 한반도


0: 제임스 얼터너티브 록밴드 R&B의 리더 존 레스터가 18일 방송 예정인 미국 CBS 뉴스에 출연해 미국   2019 올해의 밴드  를 꼽으며 자신이 가장 사랑하는 밴드라고 밝히며  존 레
1: 제임스 얼라이언스 CEO는 지난달 29일 삼성전자 서초사옥에서 삼성전자 관계자 및 임직원들이 모여  삼성전자서비스  삼성SDS와의 합병 추진에 대한 의견을 나누며  합병의 필요성에 대해 공감


0: 수학전문기업 이투스교육 과 수학전문교육기업 에듀윌  입시전략연구소가  2019년 9월 21일부터 27일까지  2019 수능시험  EBS 교재  수험생 지원 이벤트 를 실시한다고 밝혔다 <
1: 수학능력시험이 치러진 5일 수험생들은 시험장 앞에 줄을 선 채 고사장을 찾아 긴장한 모습이 역력했다 <summarize>수험생  늦잠 자려요     수


