# In [5]:
import sys
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

def download_image(url, filename, timeout=10):
    """Download an image from *url*, save it as RGB to *filename*, and return it.

    The HTTP response must have status 200 and a ``Content-Type`` header
    containing "image" (compared case-insensitively). Otherwise a ValueError
    is raised. Any failure is printed before being re-raised, so the caller
    still sees the original exception.

    Args:
        url: HTTP(S) URL of the image.
        filename: Path the RGB-converted image is written to.
        timeout: Seconds passed to ``requests.get`` before giving up (default 10).

    Returns:
        The downloaded image, converted to RGB.

    Raises:
        ValueError: If the response is not a successful image response.
        Exception: Whatever ``requests``/PIL raise (network errors, decode or
            write failures) is re-raised unchanged after being printed.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # .get() so a response without a Content-Type header produces the
        # intended ValueError below instead of a bare KeyError('content-type').
        # MIME types are case-insensitive, hence .lower().
        content_type = response.headers.get("content-type", "").lower()
        if response.status_code != 200 or "image" not in content_type:
            raise ValueError("URL did not return a valid image")
        # Image.open is lazy and keeps a handle on its source; close it once
        # the RGB copy (a new, fully loaded image) has been made.
        with Image.open(BytesIO(response.content)) as src:
            img = src.convert("RGB")
        img.save(filename)
        return img
    except Exception as e:
        # Top-level boundary for this helper: report, then propagate unchanged.
        print(f"Error downloading image: {e}")
        raise

# Image URL and the local file it is saved to.
image_url = "https://images.unsplash.com/photo-1561037404-61cd46aa615b"
image_filename = "dog.jpg"

# Download the image. Every later step needs it, so abort on failure;
# download_image() has already printed the underlying error.
print("Downloading image...")
try:
    image = download_image(image_url, image_filename)
except Exception:
    print("Failed to download image. Please check the URL or try manually downloading.")
    # sys.exit, not the site-provided exit(): exit() is meant only for the
    # interactive shell and may be missing (python -S). Under a Jupyter kernel
    # it is IPython's exit helper, which does not reliably stop the running
    # cell, so execution could continue without `image` defined.
    sys.exit(1)

# Choose where inference runs: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained CLIP ViT-B/32 weights, moved straight onto the chosen device,
# plus the processor from the same checkpoint (image preprocessing + tokenizer).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Candidate captions. CLIP does not write captions: it embeds the image and
# each candidate text, and the closest candidate is reported below.
captions = [
    "A dog sits in a grassy field",
    "A cat sleeps on a couch",
    "A sunset over a beach",
    "Kids play soccer",
    "A bird flies in the sky",
]

# Tokenize the captions (padded to a common length) and preprocess the image
# in one call, then move every returned tensor onto `device`.
inputs = clip_processor(
    text=captions, images=image, return_tensors="pt", padding=True
).to(device)

# A single forward pass yields both embeddings; no gradients are needed.
with torch.no_grad():
    outputs = clip_model(**inputs)
    # one image -> presumably [1, dim]; one row per caption -> [num_captions, dim]
    image_embedding, text_embeddings = outputs.image_embeds, outputs.text_embeds

# Scale each embedding to unit length so the dot product is cosine similarity.
# NOTE(review): CLIPModel in recent transformers versions already returns
# L2-normalized image_embeds/text_embeds, so this is likely a harmless no-op;
# confirm against the pinned version before relying on either behavior.
image_embedding, text_embeddings = (
    emb / emb.norm(dim=-1, keepdim=True)
    for emb in (image_embedding, text_embeddings)
)
similarities = torch.matmul(image_embedding, text_embeddings.T).squeeze(0)

# Report the highest-scoring candidate.
best_idx = int(torch.argmax(similarities))
best_caption = captions[best_idx]
print("Generated Caption:", best_caption)
# This is retrieval, not generation: the printed text is only the best match
# among `captions`, so it can be imprecise when no candidate fits exactly.

# Output:
# Downloading image...
# Generated Caption: A dog sits in a grassy field
