In [2]:
# Prep & dependencies

import requests, os
import torch, time
from transformers import CLIPModel, CLIPProcessor

from IPython.display import display, Image, HTML
import torch.nn.functional as F  # Import softmax from PyTorch

In [3]:
# Run to load base model
# Loads the base CLIP checkpoint and processor, then the precomputed tensors.
# NOTE(review): this cell does no device placement (no `.to(device)`, no
# `map_location`). The next cell reloads the same model, processor and tensors
# with explicit device handling, so under Restart & Run All everything is loaded
# twice; the saved execution counts (In [2], In [3] above In [1]) also show these
# cells were run out of order. Consider deleting this cell.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load base mapimg
mapimg = torch.load('mapimg.pt')
mapimg_idx = torch.load('mapimg_idx.pt')
mapimg_normalized = torch.load('mapimg_normalized.pt') 
print(mapimg.shape)  # sanity check: shape of the embedding matrix

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


torch.Size([562842, 512])


In [1]:
import gradio as gr
import torch
import time
import requests
from PIL import Image
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model (on the GPU when one is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load precomputed embeddings and index
# NOTE(review): `mapimg` is not read by the search functions in this cell;
# confirm it is still needed (and needed on `device`).
mapimg = torch.load("mapimg.pt", map_location=device)  # Precomputed embeddings (presumably one row per indexed image — confirm)
mapimg_idx = torch.load("mapimg_idx.pt")  # Row index -> image URL; fetch_images downloads these
# The search functions multiply this matrix by query embeddings computed on
# `device`, so load it there as well. Without map_location it comes back on
# whatever device it was saved from, and a CPU/GPU mismatch makes torch.matmul fail.
mapimg_normalized = torch.load("mapimg_normalized.pt", map_location=device)  # Normalized embeddings for cosine similarity

def fetch_images(image_urls, timeout=10):
    """Download each URL and decode it as an image, skipping failures.

    Best-effort: any error for a URL is printed and that URL is skipped, so the
    result can be shorter than ``image_urls``.

    Args:
        image_urls: iterable of image URLs.
        timeout: seconds to wait for each HTTP request before giving up.

    Returns:
        List of decoded images, in input order, for the URLs that succeeded.
    """
    # NOTE(review): `Image` must be PIL.Image here. The first cell's
    # `from IPython.display import ... Image` binds the same name, so re-running
    # that cell after the PIL import would break `Image.open` below.
    images = []
    for url in image_urls:
        try:
            # Without a timeout, one unresponsive host would hang the whole search.
            response = requests.get(url, timeout=timeout)
            # Fail on 4xx/5xx instead of trying to decode an error page as an image.
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            # Image.open is lazy; decode now so a corrupt file is caught here
            # rather than later, when the UI renders it.
            img.load()
            images.append(img)
        except Exception as e:  # deliberate best-effort: report and skip this URL
            print(f"Failed to load image from {url}: {e}")
    return images

def _rank_and_fetch(query_embeds, top_k=5):
    """Rank the precomputed index against one query embedding and fetch the best images.

    Shared by text_based_search and image_based_search.

    Args:
        query_embeds: CLIP features for a single query (presumably shape (1, D),
            since each caller encodes exactly one text or image).
        top_k: maximum number of results; clamped to the number of indexed rows.

    Returns:
        Images from ``fetch_images`` for the top-ranked ``mapimg_idx`` entries,
        best match first. Failed downloads are skipped, so the list may be
        shorter than ``top_k``.
    """
    # L2-normalise so the dot product below is a cosine similarity against the
    # (already normalised) index rows.
    query_embeds = query_embeds / query_embeds.norm(p=2, dim=-1, keepdim=True)
    # Compute where the index lives; it is not guaranteed to share the model's device.
    query_embeds = query_embeds.to(device=mapimg_normalized.device, dtype=mapimg_normalized.dtype)

    # (N, D) @ (D, 1) -> (N,) similarity per indexed row. CLIP's logit_scale is a
    # positive constant, so scaling by it cannot change the ranking and is skipped.
    scores = torch.matmul(mapimg_normalized, query_embeds.t()).squeeze(-1)
    top_indices = torch.topk(scores, k=min(top_k, scores.shape[0])).indices

    results = [mapimg_idx[i] for i in top_indices.tolist()]
    return fetch_images(results)


def text_based_search(query, top_k=5):
    """Search the map index with a free-text query.

    " map" is appended when the query does not already mention maps, to steer
    CLIP towards map imagery.

    Args:
        query: search text.
        top_k: number of results to return (default 5).

    Returns:
        List of images for the best-matching index entries.

    Raises:
        gr.Error: if ``query`` is empty or whitespace-only (shown to the user by Gradio).
    """
    if not query or not query.strip():
        raise gr.Error("Please enter a search query.")

    query = query.strip()
    # Case-insensitive so "Map of Paris" is not turned into "Map of Paris map".
    if "map" not in query.lower():
        query += " map"

    # Inference only: no_grad avoids building an autograd graph on every request.
    with torch.no_grad():
        # truncation: CLIP's text encoder has a fixed maximum token length, so
        # overly long queries would otherwise fail inside the model.
        text_inputs = processor(text=query, return_tensors="pt", padding=True, truncation=True).to(device)
        text_embeds = model.get_text_features(**text_inputs)

    return _rank_and_fetch(text_embeds, top_k)


def image_based_search(image, top_k=5):
    """Search the map index with an example image.

    Args:
        image: the uploaded query image.
        top_k: number of results to return (default 5).

    Returns:
        List of images for the best-matching index entries.

    Raises:
        gr.Error: if no image was supplied (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Inference only: no_grad avoids building an autograd graph on every request.
    with torch.no_grad():
        # `padding` is a text-tokenizer option, so it is not passed for image-only input.
        image_inputs = processor(images=image, return_tensors="pt").to(device)
        image_embeds = model.get_image_features(**image_inputs)

    return _rank_and_fetch(image_embeds, top_k)

def search_map(query_text=None, image=None):
    """Gradio entry point: dispatch to text search or image search.

    Text takes precedence when both inputs are given.

    Args:
        query_text: text from the search box; empty or whitespace-only counts as absent.
        image: the uploaded image, or None when nothing was uploaded.

    Returns:
        Whatever the chosen search function returns (the images for the gallery).

    Raises:
        gr.Error: when neither input is provided. Gradio shows the message to the
            user, whereas a plain string returned to a gr.Gallery output cannot be
            displayed as images.
    """
    if query_text and query_text.strip():
        return text_based_search(query_text)
    # `is not None` rather than truthiness: an image object's truth value is not a
    # reliable "was something uploaded" test (e.g. NumPy arrays raise on bool()).
    if image is not None:
        return image_based_search(image)
    raise gr.Error("Please provide either text or an image.")

# Gradio UI
# The two optional inputs map, in order, to search_map's `query_text` and
# `image` parameters; the returned images are shown in a gallery.
iface = gr.Interface(
    fn=search_map,
    inputs=[
        gr.Textbox(label="Enter search text (Optional)"),
        gr.Image(type="pil", label="Upload an image (Optional)")
    ],
    outputs=gr.Gallery(label="Search Results"),  # Display images in a gallery
    title="CLIP Map Search",
    description="Search maps using either text or an uploaded image."
)

# Launch Gradio UI
# share=True also opens a temporary public *.gradio.live URL (see the cell
# output), so anyone with the link can run searches against this kernel.
# NOTE(review): drop share=True if the app should stay local.
iface.launch(share=True)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://4c5b9ca9dfd7b5d524.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


