In [2]:
import os

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

In [3]:
# Hugging Face Hub checkpoint used for all three components. Loading the model,
# image processor and tokenizer from one constant keeps them from drifting apart.
MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

Downloading (…)lve/main/config.json: 100%|██████████| 4.61k/4.61k [00:00<00:00, 5.89MB/s]
Downloading pytorch_model.bin: 100%|██████████| 982M/982M [00:04<00:00, 213MB/s] 
Downloading (…)rocessor_config.json: 100%|██████████| 228/228 [00:00<00:00, 1.08MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 241/241 [00:00<00:00, 1.08MB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 3.60MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 90.0MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 1.36M/1.36M [00:00<00:00, 6.45MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 120/120 [00:00<00:00, 443kB/s]


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the parameters in place (and returns the module), so no
# reassignment is needed. The trailing ';' stops IPython from rendering the
# entire model architecture as this cell's output.
model.to(device);

In [20]:
# Decoding settings forwarded to model.generate(): beam search over
# `num_beams` hypotheses, returning the best `num_return_sequences` of them
# for each input image, with `max_length` as the token-length budget.
max_length = 16
num_beams = 4
num_return_sequences = 3

gen_kwargs = dict(
    max_length=max_length,
    num_beams=num_beams,
    num_return_sequences=num_return_sequences,
)
def predict_step(image_paths):
    """Generate captions for one or more image files.

    Args:
        image_paths: An iterable of image file paths. A single ``str`` or
            ``os.PathLike`` path is also accepted and treated as a one-item list.

    Returns:
        A flat list of stripped caption strings produced by ``model.generate``
        with the global ``gen_kwargs``. Its length is
        ``len(images) * gen_kwargs["num_return_sequences"]``; per the
        transformers ``generate`` docs, the sequences for each input image are
        consecutive, in input order. An empty input returns ``[]`` without
        running the model.
    """
    # A bare string is itself iterable, so without this guard 'img.png' would
    # be treated as the paths 'i', 'm', 'g', ...
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]

    images = []
    for image_path in image_paths:
        # The context manager closes the file handle. convert() loads the
        # pixels and returns a new image (a copy even when the source is
        # already RGB), so the result stays valid after the file is closed.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    if not images:
        return []

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [caption.strip() for caption in captions]

In [16]:
# Captions for the first sample image (one entry per returned beam).
predict_step(image_paths=["Image1.png"])

['a man kicking a soccer ball on a field',
 'a soccer player in action on the field',
 'a man kicking a soccer ball in the air']

In [21]:
# Captions for the second sample image (one entry per returned beam).
predict_step(image_paths=["Image2.png"])

['a woman is standing in a field with a horse',
 'a woman standing in a field with a horse',
 'a woman is standing in the middle of a field']

In [22]:
# Captions for the third sample image (one entry per returned beam).
predict_step(image_paths=["Image3.png"])

['a collage of photos showing a woman holding a sign',
 'a collage of photos showing a woman and a man',
 'a collage of photos showing a woman in a pink dress and a man']