In [1]:
import io

import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout the call can block indefinitely.

    Returns:
        A PIL image converted to ``"RGB"`` mode.

    Raises:
        Exception: Wrapping any network, HTTP-status or decoding error; the
            original exception is attached as ``__cause__``.
    """
    try:
        # The context manager releases the connection even if decoding fails.
        with requests.get(url, timeout=timeout) as response:
            # Fail fast on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
            # ``response.content`` is already transfer-decoded (gzip/deflate),
            # unlike ``response.raw``, which yields the undecoded byte stream.
            return Image.open(io.BytesIO(response.content)).convert("RGB")
    except Exception as e:
        raise Exception(f"Error loading image: {e}") from e

def setup_model(device="cuda" if torch.cuda.is_available() else "cpu",
                model_name="openai/clip-vit-base-patch32"):
    """Load a CLIP model and its matching processor.

    Args:
        device: Target device, e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"`` or a
            ``torch.device``. The default is evaluated once, when this function
            is defined.
        model_name: Hugging Face Hub id (or local path) of the CLIP checkpoint.
            The same id is used for both the model and the processor.

    Returns:
        ``(model, processor)``: the model moved to ``device`` and switched to
        eval mode, plus the processor loaded from ``model_name``.

    Raises:
        Exception: Wrapping any loading error; the original exception is
            attached as ``__cause__``.
    """
    try:
        # Compare the device *type* so "cuda:0" and torch.device("cuda") also
        # get fp16 weights, not only the exact string "cuda".
        is_cuda = torch.device(device).type == "cuda"
        dtype = torch.float16 if is_cuda else torch.float32
        # Load, then move with .to(): device_map=... relies on the optional
        # `accelerate` package in many transformers versions.
        model = CLIPModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
        model.eval()  # inference only
        processor = CLIPProcessor.from_pretrained(model_name)
        return model, processor
    except Exception as e:
        raise Exception(f"Error setting up model: {e}") from e

def process_image(image, processor, model, device, texts=None):
    """Score one image against a set of text prompts with CLIP.

    Args:
        image: Image passed to ``processor`` (e.g. the PIL image returned by
            ``load_image``).
        processor: Callable accepting ``text=``, ``images=``,
            ``return_tensors=`` and ``padding=`` and returning a mapping of
            tensors.
        model: Callable taking those tensors as keyword arguments and returning
            an object with a ``logits_per_image`` tensor.
        device: Device the processed input tensors are moved to. It should
            match the device the model lives on.
        texts: Candidate prompts. Defaults to
            ``["a photo of a cat", "a photo of a dog"]``.

    Returns:
        Softmax of ``logits_per_image`` over the prompt dimension (dim=1).
        ``probs[0][i]`` is the score of ``texts[i]`` for the image; the shape
        is presumably ``(1, len(texts))``.

    Raises:
        ValueError: If ``texts`` is empty.
        Exception: Wrapping any processing or inference error; the original
            exception is attached as ``__cause__``.
    """
    if texts is None:
        texts = ["a photo of a cat", "a photo of a dog"]
    texts = list(texts)
    if not texts:
        raise ValueError("texts must contain at least one prompt")

    try:
        inputs = processor(
            text=texts,
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        # Inputs must live on the same device as the model weights.
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)

        return outputs.logits_per_image.softmax(dim=1)
    except Exception as e:
        raise Exception(f"Error processing image: {e}") from e

def main():
    """Run the demo end to end: load CLIP, fetch a sample COCO image, print cat/dog scores."""
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, processor = setup_model(device)

    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    picture = load_image(image_url)

    scores = process_image(picture, processor, model, device)

    # Names are listed in the same order as the prompts process_image scores
    # by default ("a photo of a cat", "a photo of a dog").
    for idx, name in enumerate(("cat", "dog")):
        print(f"Probability of {name}: {scores[0][idx].item():.3f}")


if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so executing this cell runs the demo.
    main()

Using device: cuda
Probability of cat: 0.995
Probability of dog: 0.005
