In [3]:
import torch
import random
import numpy as np

from models.prefix_gptneox_model import PrefixGPTNeoXLMHeadModel
from utils.args_utils import Args

In [4]:
def set_seed(seed=100):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    n_gpu = torch.cuda.device_count()
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
        
random_seed = 100
set_seed(random_seed)

In [5]:
# MODEL ARGS
args = Args()

args.pretrained_model = "EleutherAI/polyglot-ko-1.3b"
args.special_tokens = None
# Pretrained LM 웨이트 고정
args.freeze_plm = True
# Prefix 웨이트 학습
args.freeze_prefix = False

# hyperparams
args.prefix_dropout = 0.1
args.prefix_sequence_length = 8
args.mid_dim = 800

In [6]:
# Load Initial Model
model = PrefixGPTNeoXLMHeadModel(args)

processed_dict = torch.load("prefix_weights/gptneox_ep30_1r1e-5.bin")
print(processed_dict.keys())
# strict=False 여야 부분 웨이트만 로드
model.load_state_dict(processed_dict, strict=False)
model.eval()
print("Trained Model")

prefix-tuning sequence length is 8.
dict_keys(['input_tokens', 'wte.weight', 'control_trans.0.weight', 'control_trans.0.bias', 'control_trans.2.weight', 'control_trans.2.bias'])
Trained Model


In [7]:
model2 = PrefixGPTNeoXLMHeadModel(args)
model2.eval()
print("BASELINE COMPARISION MODEL")

prefix-tuning sequence length is 8.
BASELINE COMPARISION MODEL


In [16]:
# generate 함수 이용
s = "<s>새신발 샀는데 비와."
model_in = s + ' [A] '
inputs = model.tokenizer([model_in], max_length=256, return_tensors="pt", add_special_tokens=True)
print(inputs["input_ids"])
generated_ids = model.generate(inputs["input_ids"], \
        attention_mask = inputs["attention_mask"], \
        num_beams=1, min_length=32, do_sample = False, \
        max_length=63, repetition_penalty = 1.2, no_repeat_ngram_size = 3, early_stopping = True)

print(model.tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True)[0])

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


tensor([[   31,    86,    33,  1569, 29093,  9144,   829,   563,   441,    17,
          5485,    36,    64,   224]])
<s>새신발 샀는데 비와. [A] okay!</s>비가 오면 안되겠네요.</S>비오는 날은 조심하세요.<s>(조심해서 나쁠 건 없죠)</A>그렇긴 하죠.<


In [17]:
# generate 함수 이용
s = "<s>새신발 샀는데 비와."
model_in = s + ' [A] '
inputs = model.tokenizer([model_in], max_length=256, return_tensors="pt", add_special_tokens=True)
print(inputs["input_ids"])
generated_ids = model2.generate(inputs["input_ids"], \
        attention_mask = inputs["attention_mask"], \
        num_beams=5, min_length=32, do_sample = False, \
        max_length=63, repetition_penalty = 1.2, no_repeat_ngram_size = 3, early_stopping = True)

print(model2.tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True)[0])

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


tensor([[   31,    86,    33,  1569, 29093,  9144,   829,   563,   441,    17,
          5485,    36,    64,   224]])
<s>새신발 샀는데 비와. [A] o o [B] u u [C] e e e [D] d d [E] ed ed [F] f f [G] g g [H] he he [I] i
