In [1]:
!pip install litserve

Collecting litserve
  Downloading litserve-0.2.3-py3-none-any.whl.metadata (16 kB)
Collecting fastapi>=0.100 (from litserve)
  Downloading fastapi-0.115.2-py3-none-any.whl.metadata (27 kB)
Collecting httpx (from litserve)
  Downloading httpx-0.27.2-py3-none-any.whl.metadata (7.1 kB)
Collecting uvicorn>=0.29.0 (from uvicorn[standard]>=0.29.0->litserve)
  Downloading uvicorn-0.31.1-py3-none-any.whl.metadata (6.6 kB)
Collecting starlette<0.41.0,>=0.37.2 (from fastapi>=0.100->litserve)
  Downloading starlette-0.39.2-py3-none-any.whl.metadata (6.0 kB)
Collecting h11>=0.8 (from uvicorn>=0.29.0->uvicorn[standard]>=0.29.0->litserve)
  Downloading h11-0.14.0-py3-none-any.whl.metadata (8.2 kB)
Collecting httptools>=0.5.0 (from uvicorn[standard]>=0.29.0->litserve)
  Downloading httptools-0.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)
Collecting python-dotenv>=0.13 (from uvicorn[standard]>=0.29.0->litserve)
  Downloading py

In [9]:
import io
import os
import urllib.parse as parse

import requests
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, GPT2TokenizerFast

# Verify URL function
def check_url(string):
    """Return True if ``string`` looks like an absolute URL that has a path.

    A value counts as a URL only when scheme (e.g. ``https``), network
    location (host) and path are all non-empty. So ``https://host/a.png``
    passes, while ``https://host`` and local paths such as
    ``/content/x.png`` do not.

    Args:
        string: Candidate URL or filesystem path.

    Returns:
        bool: True if all three URL components are present; False otherwise,
        including when the value cannot be parsed at all.
    """
    try:
        result = parse.urlparse(string)
    except (ValueError, TypeError, AttributeError):
        # urlparse raises ValueError on malformed input (e.g. an unclosed IPv6
        # bracket) and TypeError/AttributeError on non-string input. This is
        # narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        return False
    return all((result.scheme, result.netloc, result.path))

# Load an image from a URL or local path
def load_image(image_path, timeout=30):
    """Open an image from an HTTP(S) URL or a local file path.

    Args:
        image_path: Either a URL accepted by ``check_url`` or a path to an
            existing local file.
        timeout: Seconds to wait for the remote server (URLs only). Without
            a timeout, ``requests`` can block indefinitely.

    Returns:
        PIL.Image.Image: The opened image. Local files are copied into
        memory so no file handle is left open.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        ValueError: If ``image_path`` is neither a URL nor an existing file.
    """
    if check_url(image_path):
        response = requests.get(image_path, timeout=timeout)
        # Fail with a clear HTTP error instead of letting PIL try to decode
        # an HTML error page.
        response.raise_for_status()
        # ``response.content`` is already decompressed (gzip/deflate), unlike
        # ``response.raw``, which is the undecoded byte stream.
        return Image.open(io.BytesIO(response.content))
    if os.path.isfile(image_path):
        # Copy inside ``with`` so the pixel data is loaded and the file handle
        # is closed right away instead of lingering until garbage collection.
        with Image.open(image_path) as image:
            return image.copy()
    raise ValueError(f"Invalid image path: {image_path}")

# HuggingFace API class for image captioning with accelerator support
class ImageCaptioningLitAPI:
    """Caption images with the ``nlpconnect/vit-gpt2-image-captioning`` model.

    Exposes LitServe-style ``setup``/``predict`` methods. Call ``setup`` once
    to load the model, tokenizer and image processor, then call ``predict``
    once per image.
    """

    MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"

    def setup(self, **kwargs):
        """Pick a device and load the model components onto it.

        Keyword Args:
            accelerator: ``"auto"`` (default), ``"cuda"`` or ``"gpu"`` uses a
                CUDA GPU when one is available. Any other value (e.g.
                ``"cpu"``) forces the CPU. A GPU request on a machine without
                CUDA falls back to the CPU and prints a notice.
            devices, workers_per_device: Accepted for call compatibility with
                LitServe-style configs, but ignored; this class runs on a
                single device in-process.
        """
        accelerator = str(kwargs.get("accelerator", "auto")).lower()
        cuda_available = torch.cuda.is_available()
        wants_gpu = accelerator in ("auto", "cuda", "gpu")
        self.device = "cuda" if wants_gpu and cuda_available else "cpu"

        if accelerator in ("cuda", "gpu") and not cuda_available:
            print(f"Requested accelerator={accelerator!r} but CUDA is unavailable; falling back to CPU.")
        print(f"Using device: {self.device}")

        # Encoder-decoder model, its GPT-2 tokenizer (for decoding the
        # generated ids), and the ViT image processor (for pixel inputs).
        self.model = VisionEncoderDecoderModel.from_pretrained(self.MODEL_NAME).to(self.device)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(self.MODEL_NAME)
        self.image_processor = ViTImageProcessor.from_pretrained(self.MODEL_NAME)

    def predict(self, image_path, **generate_kwargs):
        """Generate a caption for a single image.

        Args:
            image_path: URL or local path (opened via ``load_image``), or an
                already-opened ``PIL.Image.Image``.
            **generate_kwargs: Optional overrides forwarded to
                ``model.generate`` (e.g. ``max_length=16, num_beams=4``).
                With none given, the model's default generation config is used.

        Returns:
            str: The decoded caption with special tokens and surrounding
            whitespace removed.
        """
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = load_image(image_path)

        # PNGs are often RGBA, palette ("P") or grayscale ("L"). Convert to
        # 3-channel RGB before preprocessing, as the model card's example does.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.image_processor(image, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(**inputs, **generate_kwargs)

        caption = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
        # Decoded text can carry trailing whitespace; strip it as the model card does.
        return caption.strip()

# Build the captioner once and load the model; later cells reuse `api`.
# Note: setup() reads devices/workers_per_device but does not use them.
api = ImageCaptioningLitAPI()
api.setup(
    accelerator="cuda",
    devices=1,
    workers_per_device=1,
)


Using device: cuda
Generated Caption: a man in a suit and hat standing in front of a building 


In [13]:
# Example 1 — caption a local image file and print the result.
image_path = "/content/cbimage.png"
caption = api.predict(image_path=image_path)
print(f"Generated Caption: {caption}")

Generated Caption: a man in a suit and tie holding a red and white flag 


In [16]:
# Example 3 — caption a second local image (file name contains a space).
image_path = "/content/image (2).png"
caption = api.predict(image_path=image_path)
print(f"Generated Caption: {caption}")

Generated Caption: a view from a boat of a beach with a large body of water 
