<a href="https://colab.research.google.com/github/chowalex/photo-search/blob/main/Semantic_Search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a demo of indexing a collection of images with CLIP, enabling semantic search of those images. Specifically it does the following:
1. Pulls images from an Azure blob storage container
2. Computes embeddings of each image with CLIP
3. Saves vectors in Pinecone
4. Computes the embedding for a text query, retrieves the top-K results from Pinecone, and displays those images

# Prerequisites
These secrets must be stored as Colab notebook secrets (the code reads them with `google.colab.userdata.get`, not from plain environment variables):
* `HF_TOKEN`: Hugging Face token
* `PINECONE_API_KEY`: Pinecone API Key
* `AZURE_CONNECTION_STRING`: Connection string for the Azure storage account that holds the blob container

In addition, set these constants based on the specific Azure and Pinecone setup:
* `CONTAINER_NAME`
* `PINECONE_INDEX`
* `PINECONE_NAMESPACE`

In [None]:
!pip install pinecone-client
!pip install azure-storage-blob

import torch
import os
import numpy as np
from PIL import Image
from pinecone import Pinecone, ServerlessSpec
from google.colab import userdata
from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel

In [None]:
# Set up CLIP model and parameters
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
# Move model to device if possible and make inference mode explicit.
# Assigning (instead of leaving `model.to(device)` as the last expression)
# keeps Jupyter from dumping the full model repr as cell output.
model = CLIPModel.from_pretrained(model_id).to(device).eval()

# Size of CLIP's shared image/text embedding space (512 for ViT-B/32).
# Derived from the model config so the Pinecone index dimension stays correct
# if `model_id` is changed.
DIMENSION = model.config.projection_dim

In [20]:
# Azure and Pinecone settings
CONNECTION_STRING = userdata.get('AZURE_CONNECTION_STRING')
CONTAINER_NAME = 'search'
PINECONE_API_KEY = userdata.get('PINECONE_API_KEY')
PINECONE_INDEX = 'photo-search'
PINECONE_NAMESPACE = 'clip-demo'
BATCH_SIZE = 10  # number of images embedded and upserted per round trip

# Fail fast on missing credentials instead of deep inside a client call.
if CONNECTION_STRING is None:
  raise ValueError("AZURE_CONNECTION_STRING environment variable not set.")
if PINECONE_API_KEY is None:
  raise ValueError("PINECONE_API_KEY environment variable not set.")

# Initialize the connection to Azure Blob Storage
blob_service_client = BlobServiceClient.from_connection_string(CONNECTION_STRING)
container_client = blob_service_client.get_container_client(CONTAINER_NAME)

# Set up pinecone client
pc = Pinecone(api_key=PINECONE_API_KEY)

# Create the index if it does not already exist. The cosine metric suits the
# L2-normalized CLIP embeddings this notebook stores.
if PINECONE_INDEX not in pc.list_indexes().names():
  pc.create_index(
      name=PINECONE_INDEX,
      dimension=DIMENSION,
      metric="cosine",
      spec=ServerlessSpec(cloud="aws", region="us-east-1"),
  )
else:
  # An index built for a different embedding size would only fail later,
  # at upsert/query time, with a less obvious error.
  existing_dimension = pc.describe_index(PINECONE_INDEX).dimension
  if existing_dimension != DIMENSION:
    raise ValueError(
        f"Pinecone index '{PINECONE_INDEX}' has dimension {existing_dimension}, "
        f"but the CLIP model produces {DIMENSION}-dimensional embeddings.")
index = pc.Index(PINECONE_INDEX)

In [None]:
# Enumerate every blob in the container; each one is treated as an image to index.
blob_names = [blob.name for blob in container_client.list_blobs()]
num_images = len(blob_names)
print(f"Found {num_images} blobs in the container.")

def download_blob_as_image(blob_name):
  """Fetch one blob from the Azure container and open it as a PIL image.

  Args:
    blob_name: Name of the blob inside `container_client`.

  Returns:
    The result of `Image.open` on an in-memory copy of the blob's bytes.
  """
  raw_bytes = (
      container_client
      .get_blob_client(blob_name)
      .download_blob()
      .readall()
  )
  return Image.open(BytesIO(raw_bytes))

def get_embeddings_for_batch(blob_names):
  """Download a batch of image blobs and return their L2-normalized CLIP embeddings.

  Args:
    blob_names: Sequence of blob names in the Azure container to embed.

  Returns:
    A numpy array with one unit-length embedding row per blob name, in the
    same order as `blob_names`.
  """
  images = [download_blob_as_image(blob_name) for blob_name in blob_names]

  pixel_values = processor(
      text=None,
      images=images,
      return_tensors='pt'
  )['pixel_values'].to(device)

  # Inference only: no_grad avoids building an autograd graph for every batch.
  with torch.no_grad():
    images_emb = model.get_image_features(pixel_values=pixel_values)

  # L2-normalize each row so dot product equals cosine similarity.
  images_emb = images_emb.cpu().numpy()
  return images_emb / np.linalg.norm(images_emb, axis=1, keepdims=True)

# Embed the images batch by batch and upsert them into Pinecone.
# The vector ID is the blob name, so query matches map straight back to blobs.
for batch_start in range(0, num_images, BATCH_SIZE):
  blob_batch = blob_names[batch_start : batch_start + BATCH_SIZE]
  images_emb_norm = get_embeddings_for_batch(blob_batch)
  print(f"Successfully processed batch with normalized embeddings {images_emb_norm.shape}. Min: {images_emb_norm.min()}, Max: {images_emb_norm.max()}. Batch: {blob_batch}")

  # Pair each blob with its own embedding row directly (rather than via index
  # arithmetic into the global list) and send plain float lists to Pinecone.
  vectors = [
      {"id": blob_name, "values": emb.tolist()}
      for blob_name, emb in zip(blob_batch, images_emb_norm)
  ]

  index.upsert(
      vectors=vectors,
      namespace=PINECONE_NAMESPACE
  )

In [None]:
# Embed a free-text query with CLIP and show the closest images from Pinecone.
query = "aquarium"
TOP_K = 2  # number of matches to retrieve and display

tokens = processor(
    text=[query],
    padding=True,
    # CLIP's text encoder has a fixed maximum length; truncate long queries
    # instead of failing.
    truncation=True,
    images=None,
    return_tensors='pt'
    ).to(device)

# Inference only: no autograd graph needed.
with torch.no_grad():
  text_emb = model.get_text_features(**tokens)
text_emb = text_emb.cpu().numpy()

# Normalize to unit length, like the image embeddings stored in the index.
text_emb_norm = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
print(f"Computed normalized text embedding {text_emb.shape}")

# Query pinecone with the single query vector.
vector = text_emb_norm[0].tolist()
response = index.query(vector=vector, namespace=PINECONE_NAMESPACE, top_k=TOP_K)
print(response)
if 'matches' in response:
  for match in response['matches']:
    blob_name = match['id']
    display(download_blob_as_image(blob_name))