In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium")
tokenizer.do_lower_case = True
model = AutoModelForCausalLM.from_pretrained('./GPT-2/output02/')
model.to(device)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(32000, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout)

In [2]:
def generate_keigo(informal):
    input_text = '<s>'+informal+'[SEP]'
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    out = model.generate(input_ids, 
                         do_sample=False,
                         num_beams=5,
                         top_p=0.75,
                         top_k=0, 
                         num_return_sequences=1, 
                         max_length=256,
                         pad_token_id=tokenizer.pad_token_id,
                         bos_token_id=tokenizer.bos_token_id,
                         eos_token_id=tokenizer.eos_token_id,
                         bad_words_ids=[[1], [5]])
    for sent in tokenizer.batch_decode(out):
        sent = sent.split('[SEP]</s>')[1]
        sent = sent.replace('</s>', '')
        sent = sent.replace('\n', '')
        return sent

In [3]:
print(generate_keigo('先生の本を借りたい。'))

先生の本をお借りしたいです。


In [4]:
import pickle

conjugated_pkl_path = 'data/pred_all.pkl'

# load, organize, and tokenize data
with open(conjugated_pkl_path, 'rb') as f:
    all_emails = pickle.load(f)

for email_dict in all_emails:
    if email_dict['Task ID'] == 1:
        email_dict['GPT_pred'] = []
        for form_line in email_dict['Form']:
            email_dict['GPT_pred'].append(generate_keigo(form_line))

In [5]:
pickle_output_path = 'data/pred_all.pkl'

print("Saving pickle...")
with open(pickle_output_path, 'wb') as f:
    pickle.dump(all_emails, f)

Saving pickle...
