In [2]:
import requests
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
from transformers import pipeline, set_seed

In [13]:
# Load the captioning model, its tokenizer and the image processor.
#
# Other checkpoints tried earlier (swap into MODEL_CHECKPOINT to compare):
#   ./models/vit-gpt2-japanese-image-captioning_stair-captions/checkpoint-613500/
#   ./models/vit-gpt2-japanese-image-captioning_stair-captions-result
MODEL_CHECKPOINT = "./models/vit-gpt2-japanese-image-finetuning/checkpoint-2000/"
TOKENIZER_NAME = "rinna/japanese-gpt2-medium"
IMAGE_PROCESSOR_NAME = "google/vit-base-patch16-224-in21k"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_CHECKPOINT)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, use_fast=False)
# Set by hand: worked around here due to some bug of tokenizer config loading.
tokenizer.do_lower_case = True

image_processor = ViTImageProcessor.from_pretrained(IMAGE_PROCESSOR_NAME)

In [14]:
# Run inference on a single sample image and print the candidate captions.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# requests applies no timeout by default, so a stalled download would hang the
# kernel; bound it and fail loudly on HTTP errors instead of handing PIL an
# error page. The context manager releases the streamed connection.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    # convert() forces the pixel data to be read while the stream is still
    # open, and normalises non-RGB inputs (grayscale, RGBA, palette) to RGB.
    image = Image.open(response.raw).convert("RGB")

pixel_values = image_processor(image, return_tensors="pt").pixel_values

# Sampling is stochastic; fix the seed so Restart & Run All reproduces the
# same captions.
SEED = 42
set_seed(SEED)

# Autoregressively generate captions. Because num_beams > 1 and do_sample=True
# are both set, this is beam search with multinomial sampling (not greedy
# decoding); the 5 best beams are returned as separate candidates.
generated_ids = model.generate(
    pixel_values,
    max_new_tokens=30,
    num_beams=5,
    early_stopping=True,
    do_sample=True,
    temperature=1.2,
    top_k=50,
    top_p=0.95,
    no_repeat_ngram_size=3,
    num_return_sequences=5,
)

# Decode token ids back to text, dropping special tokens.
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
for text in generated_texts:
    print(text)

ピンクの服を着た人の横に猫が座っている
ピンクの服を着た人と猫が座っている
ピンクの服を着た人の隣に猫が座っている
ピンクの服を着た男性が横になっている
ピンクの服を着た人と猫が座っている


In [15]:
# Intentionally disabled: uncomment to write the currently loaded `model` to the results directory.
# model.save_pretrained("./models/vit-gpt2-japanese-image-captioning_stair-captions-result/")

In [16]:
# Wrap the already-loaded model, tokenizer and image processor in an
# image-to-text pipeline so a caption can be requested with a single call.
pl = pipeline(
    "image-to-text",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=image_processor,
)

In [17]:
# Caption the image at `url` through the pipeline; the result is displayed as the cell output.
pl(url)

[{'generated_text': 'ピンクの服を着た女性が横になっている'}]

In [18]:
# Caption an image given by a relative file path instead of a URL.
pl("./images/0202.png")

[{'generated_text': '男性が驚いている'}]