In [41]:
from transformers import AutoTokenizer
from main import TextGenerationModel, KLUEDatamodule, parse_args
import pytorch_lightning as pl

MODEL_CKPT_NAME = "val_loss:val_loss=6.0762176513671875-val_bleu:val_bleu_score=0.0.ckpt"

args = {
    "dataset_name": "klue",
    "dataset_subset_name": "sts",
    "pretrained_model_name": "skt/kogpt2-base-v2",
    "max_seq_length": 64,
    "batch_size": 64,
    "num_workers": 1,
    "lr": 0.0001
}

tokenizer = AutoTokenizer.from_pretrained(
        args["pretrained_model_name"],
        bos_token='</s>',
        eos_token='</s>',
        unk_token='<unk>',
        pad_token='<pad>',
        mask_token='<mask>'
    )

model = TextGenerationModel.load_from_checkpoint(MODEL_CKPT_NAME, tokenizer=tokenizer, **args)
data = KLUEDatamodule.load_from_checkpoint(MODEL_CKPT_NAME, tokenizer=tokenizer, **args)

In [42]:
data.prepare_data()
data.setup("test")
batch = next(iter(data.test_dataloader()))

Found cached dataset klue (/Users/hi-jin/.cache/huggingface/datasets/klue/sts/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset klue (/Users/hi-jin/.cache/huggingface/datasets/klue/sts/1.0.0/e0fc3bc3de3eb03be2c92d72fd04a60ecc71903f821619cb28ca0e1e29e4233e)


  0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [43]:
batch

[tensor([[15872, 13088, 10836,  ...,     3,     3,     3],
         [ 9163, 13793, 36866,  ...,     3,     3,     3],
         [ 9961, 49491, 11732,  ...,     3,     3,     3],
         ...,
         [37367, 28478,  9208,  ...,     3,     3,     3],
         [22662,  7177,   387,  ...,     3,     3,     3],
         [19896,  7346,  8704,  ...,     3,     3,     3]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 tensor([[14353, 15872, 48442,  ...,     3,     3,     3],
         [ 9163, 13793, 22564,  ...,     3,     3,     3],
         [20743, 31000,  6958,  ...,     3,     3,     3],
         ...,
         [10070,  9673, 16372,  ...,     3,     3,     3],
         [15983,  7409, 10358,  ...,     3,     3,     3],
         [35385,  9124, 13274,  ...,     3,     3,     3]])]

In [44]:
input_ids, attention_mask, labels = batch

In [45]:
tokenizer.batch_decode(input_ids)

['간단한 음식점이 세 군데 정도 있습니다.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '단점은 로마가 대중교통이 한국에 비해 완전 불편합니다.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '음악 볼륨 좀 조정하고 싶은데 뭐라고 명령어를 말해야할 지 모르겠네<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '그런데 화장실에 하수구냄새가 나는게 조금 별로이긴했어요.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [47]:
tokenizer.batch_decode(labels)

['3개의 간단한 레스토랑이 있습니다.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '단점은 한국에 비해 로마의 대중교통이 매우 불편하다는 것입니다.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '무선 청소기 돌리는 방법을 설명해줘<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '주로 오전에 욕실에서 하수구냄새가 심하게  납니다.<pad><pad><pad><pad><pad><pad><pad><pad><p

In [48]:
loss, logits = model(input_ids, attention_mask, labels)
logits.shape

torch.Size([64, 64, 51200])

In [49]:
tokenizer.batch_decode(logits.squeeze(0).argmax(dim=-1))

['숙점은점은심한 있는 있어서 있습니다.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.요.',
 '단점은 로마의 대중 대중 대중 있다는 있다는 없다는 없다는<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '숙고고고고고고<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '숙실은 샤 하수 하수워구 냄새가 냄새가 조금요.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad