In [None]:
import os
from os import listdir, path

import numpy as np
import torch
from PIL import Image
from PyKomoran import Komoran, DEFAULT_MODEL
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    ViTFeatureExtractor,
    VisionEncoderDecoderModel,
)

from models.modified_sample.gpt2 import GPT2ModifiedSampleForCausalLM

In [None]:
# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the image-captioning stack from one Hugging Face repo:
#   - caption_tokenizer: turns generated caption ids back into text,
#   - feature_extractor: turns a PIL image into `pixel_values`,
#   - caption_model: the vision encoder-decoder, moved to `device`, eval mode.
encoder_model_name_or_path = "ddobokki/vision-encoder-decoder-vit-gpt2-coco-ko"
caption_tokenizer, feature_extractor = (
    PreTrainedTokenizerFast.from_pretrained(encoder_model_name_or_path),
    ViTFeatureExtractor.from_pretrained(encoder_model_name_or_path),
)
caption_model = VisionEncoderDecoderModel.from_pretrained(encoder_model_name_or_path)
caption_model = caption_model.to(device)
caption_model.eval()

In [None]:
# Load the poem tokenizer and model from a local checkpoint directory.
# GPT2ModifiedSampleForCausalLM is a project class; presumably a GPT-2 variant
# with a customized sampling step (TODO confirm in models/modified_sample).
#
# Use a dedicated name for the checkpoint location: binding it to `path` would
# shadow the `os.path` module imported as `path` in the imports cell, and any
# later `path.join(...)` would then hit `str.join` and raise a TypeError.
poem_model_path = "/opt/ml/outputs/checkpoint-92"
poem_tokenizer = AutoTokenizer.from_pretrained(poem_model_path)
poem_model = GPT2ModifiedSampleForCausalLM.from_pretrained(poem_model_path).to(device)
poem_model.eval()

In [None]:
# Korean morphological analyzer; the inference loop uses it to pull nouns out of
# each caption. DEFAULT_MODEL["FULL"] presumably selects PyKomoran's bundled
# full model (TODO confirm against PyKomoran docs).
komoran_model = DEFAULT_MODEL["FULL"]
komoran = Komoran(komoran_model)

In [None]:
# For every file in img_dir: caption the image, build a prompt from the
# caption's nouns, sample a poem from the poem model, and show all three.
img_dir = "/opt/ml/data/test_imgs"
POEM_SEED = 42  # poem generation samples (do_sample=True); seed so reruns match

torch.manual_seed(POEM_SEED)
# os.listdir returns names in arbitrary order, so sort them for a stable run.
# Skip non-files such as Jupyter's `.ipynb_checkpoints` directory.
for img_name in sorted(os.listdir(img_dir)):
    img_path = os.path.join(img_dir, img_name)
    if not os.path.isfile(img_path):
        continue

    # convert() returns a new, fully loaded image, so the source file handle
    # can be closed right away by the `with` block.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
    caption_ids = caption_model.generate(pixel_values, num_beams=5)
    # generate() returns a batch of sequences (a batch of one here).
    # batch_decode yields one string per sequence, so [0] is the caption.
    # Calling decode() on the 2-D batch and then [0] would at best give
    # a single character.
    caption = caption_tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]

    nouns = komoran.nouns(caption)
    # Prompt layout "@noun1, noun2@<d>\n"; presumably mirrors the poem model's
    # fine-tuning format (TODO confirm).
    input_text = f"@{', '.join(nouns)}@<d>\n"

    input_ids = poem_tokenizer.encode(input_text, return_tensors="pt").to(device)
    output_ids = poem_model.generate(
        input_ids,
        max_length=64,
        repetition_penalty=2.0,
        pad_token_id=poem_tokenizer.pad_token_id,
        eos_token_id=poem_tokenizer.eos_token_id,
        bos_token_id=poem_tokenizer.bos_token_id,
        do_sample=True,
        top_k=16,
        top_p=0.8,
    )
    # Decoded text includes the prompt and any special tokens.
    poem = poem_tokenizer.decode(output_ids[0])

    img.show()
    print(f"caption: {caption}")
    print(f"poem:\n{poem}")