In [23]:

import os
from io import BytesIO

import google.generativeai as genai
import requests
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Initialize CLIP model and processor (both are loaded from the same checkpoint
# name so the text/image preprocessing matches the model weights)
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

# Configure the Gemini API. The key is read from the GOOGLE_API_KEY environment
# variable so that a secret never has to be pasted into (and saved with) the
# notebook. When the variable is unset this falls back to the empty string.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

# Scene labels (Places365-style names) offered to CLIP as candidate places.
# Every entry must be unique: a repeated label splits its softmax probability
# between the copies and can be returned twice by a top-k query.
places_labels = [
    "airport_terminal", "aquarium", "art_gallery", "badlands", "ballroom", "bamboo_forest", "banquet_hall",
    "bar", "baseball_field", "basketball_court", "beach", "bedroom", "boardwalk", "boat_deck", "bookstore",
    "botanical_garden", "bridge", "bus_interior", "campsite", "castle", "cemetery", "church_outdoor",
    "classroom", "clothing_store", "coffee_shop", "concert_hall", "conference_room", "construction_site",
    "corn_field", "corridor", "courtyard", "dining_room", "downtown", "fire_station", "forest_path",
    "garden", "gymnasium", "harbor", "hospital_room", "hotel_room", "ice_cream_parlor", "kitchen",
    "lake", "library", "living_room", "lobby", "market", "mountain_path", "museum", "nightclub",
    "office", "park", "parking_garage", "pharmacy", "playground", "restaurant", "river", "schoolyard",
    "shopping_mall", "stadium", "street", "subway_station", "swimming_pool", "temple", "theater",
    "train_interior", "valley", "waterfall", "zoo"
]

# Object labels (COCO category names) offered to CLIP as candidate objects.
coco_labels = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
    "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
    "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard",
    "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
    "teddy bear", "hair drier", "toothbrush"
]

def load_image(image_path, timeout=10):
    """
    Load an image from a URL or local file path and return it in RGB mode.

    :param image_path: http(s) URL or local path of the image (str or
        path-like object)
    :param timeout: Seconds to wait for the HTTP server before giving up;
        only used when image_path is a URL
    :return: Loaded PIL Image in RGB mode
    :raises Exception: Any download or decoding error is printed and then
        re-raised unchanged
    """
    try:
        # Only the URL check needs a string; PIL accepts path objects as-is.
        if str(image_path).startswith(('http://', 'https://')):
            # Without a timeout a stalled server would hang the cell forever.
            response = requests.get(image_path, timeout=timeout)
            # Fail on 4xx/5xx here rather than handing an error page to PIL.
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_path

        # Image.open is lazy and keeps the file open; convert() decodes the
        # pixels into a new image, so the handle can be released right away.
        with Image.open(source) as raw_image:
            return raw_image.convert('RGB')
    except Exception as e:
        print(f"Error loading image: {e}")
        raise

def _top_labels(scores, labels, count):
    """Return the `count` best-scoring entries of `labels`, best first."""
    # Clamp so that 0 (or a negative number) yields an empty list and a
    # request larger than the label set returns every label instead of
    # making torch.topk raise.
    k = max(0, min(int(count), len(labels)))
    if k == 0:
        return []
    return [labels[i] for i in torch.topk(scores, k).indices.tolist()]


def analyze_image_with_clip(image, top_places_count=1, top_objects_count=2):
    """
    Analyze image using CLIP model to identify top places and objects.

    The image is scored against every place and object label in one forward
    pass; places and objects are then ranked separately.

    :param image: PIL Image to analyze (a single image is assumed)
    :param top_places_count: Number of top places to return; clamped to
        the range 0..len(places_labels)
    :param top_objects_count: Number of top objects to return; clamped to
        the range 0..len(coco_labels)
    :return: Tuple (top_places, top_objects) of label lists, best match first
    """
    # Combine all labels
    all_labels = places_labels + coco_labels

    # Prepare inputs for CLIP
    inputs = processor(text=all_labels, images=image, return_tensors="pt", padding=True)

    # Run CLIP model
    with torch.no_grad():
        outputs = model(**inputs)
        # Row 0 is the only image. The softmax spans places and objects
        # together, which does not change the ranking inside either group.
        probs = outputs.logits_per_image.softmax(dim=1)[0]

    # Places occupy the first len(places_labels) columns, objects the rest.
    places_end = len(places_labels)
    top_places = _top_labels(probs[:places_end], places_labels, top_places_count)
    top_objects = _top_labels(probs[places_end:], coco_labels, top_objects_count)

    print(top_places)
    print(top_objects)

    return top_places, top_objects

def generate_image_caption(image_path, context_word=None):
    """
    Caption an image by describing its CLIP labels with Gemini.

    The image itself is never sent to Gemini: CLIP picks the best-matching
    place and object labels, and Gemini turns those words into a sentence.

    :param image_path: Path or URL of the image
    :param context_word: Optional word the caption should be steered towards
    :return: Caption text produced by Gemini
    """
    # CLIP labels for the loaded image
    scene_labels, object_labels = analyze_image_with_clip(load_image(image_path))

    # Text-only Gemini model
    gemini = genai.GenerativeModel('gemini-pro')

    # Assemble the prompt: scene sentence, optional context hint, instruction
    scene_line = (
        f"The image contains a scene identified as a {', '.join(scene_labels)} "
        f"with objects such as {', '.join(object_labels)}. "
    )
    context_lines = (
        [f"Include the context of '{context_word}' in the description. "]
        if context_word
        else []
    )
    instruction_line = "Provide a brief description of the image using just the provided words."
    prompt = '\n'.join([scene_line, *context_lines, instruction_line])

    return gemini.generate_content(prompt).text




In [22]:

    # Image to caption: a URL or a local file path
    image_url = "test_images/1.jpg"

    # Ask for an optional context word; an empty answer means "no context"
    context_word = input("Enter a context word to influence the caption (or press Enter to skip): ").strip()
    if not context_word:
        context_word = None

    # Produce the caption and show it; failures are reported, not raised
    try:
        caption = generate_image_caption(image_url, context_word)
        print("\nGenerated Caption:", caption, sep="\n")
    except Exception as err:
        print(f"An error occurred: {err}")

Enter a context word to influence the caption (or press Enter to skip): bedroom 
['office']
['laptop', 'cat']

Generated Caption:
A cat sits on a desk in a bedroom, next to a laptop.
