In [None]:
import os

import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer



In [None]:
# Load the pre-trained captioning checkpoint (a ViT encoder with, per its name, a GPT-2
# decoder) plus the matching image processor and tokenizer. All three come from the
# same checkpoint so preprocessing and vocabulary agree with the model weights.
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assigning the result (instead of leaving `model.to(device)` as the last expression)
# keeps Jupyter from printing the entire model architecture as this cell's output.
model = model.to(device)

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featur

In [None]:
# Decoding settings forwarded to `model.generate`:
#   max_length -- upper bound on the generated sequence length, in tokens
#   num_beams  -- number of beams kept during beam search
max_length = 18
num_beams = 2
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

In [None]:
def predict_step(image_paths, **generate_overrides):
    """Generate one caption per image file.

    Args:
        image_paths: A sequence of image file paths (str or path-like). A single
            str / path-like value is also accepted and treated as a one-element list.
        **generate_overrides: Optional keyword arguments for ``model.generate``.
            They take precedence over the notebook-level ``gen_kwargs``
            (e.g. ``predict_step(paths, num_beams=4)``).

    Returns:
        A list of whitespace-stripped caption strings, one per input path, in input
        order. An empty input returns ``[]`` without invoking the model.

    Relies on the notebook globals ``model``, ``feature_extractor``, ``tokenizer``,
    ``device`` and ``gen_kwargs``.
    """
    # Without this, a bare string would be iterated character by character
    # and a pathlib.Path would raise TypeError.
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]
    image_paths = list(image_paths)
    if not image_paths:
        return []

    # Open each file in a context manager so its handle is released promptly.
    # convert() returns a new, fully loaded image, so it stays valid after close.
    images = []
    for image_path in image_paths:
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))

    # Extract pixel values and move to device
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values.to(device)

    # Per-call overrides win over the shared defaults in gen_kwargs.
    output_ids = model.generate(pixel_values, **{**gen_kwargs, **generate_overrides})

    # Decode predictions, dropping special tokens such as padding / end-of-sequence.
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]

In [None]:
# Caption one image; the result is a list holding a single caption string.
predict_step(image_paths=['/content/boysleeping.jpeg'])

['a young boy laying on a bed with a blanket']

In [None]:
# Caption one image. NOTE(review): the recorded caption for this file mentions a
# "chocolate frosted donut" although the filename suggests ice cream. This is presumably
# a misidentification by the model; verify against the actual image.
predict_step(image_paths=['/content/icecream.jpeg'])

['a little girl eating a chocolate frosted donut']

In [None]:
# Caption one PNG frame (a stock-video thumbnail, judging by its filename).
predict_step(
    image_paths=['/content/videoblocks-depressed-young-woman-sitting-on-floor-feeling-desperate_s4na26umi_thumbnail-1080_01.png']
)

['a woman sitting on the floor with her legs crossed']