In [1]:
import os
import faiss
import numpy as np
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def initialize_models(
    caption_model_name="Salesforce/blip-image-captioning-large",
    embedder_model_name="all-MiniLM-L6-v2",
):
    """Load the image-captioning model and the text-embedding model.

    Args:
        caption_model_name: Checkpoint name handed to both
            ``BlipProcessor.from_pretrained`` and
            ``BlipForConditionalGeneration.from_pretrained``.
        embedder_model_name: Model name handed to ``SentenceTransformer``.

    Returns:
        A ``(caption_processor, caption_model, embedder)`` tuple.

    Note:
        The rest of this notebook sizes its FAISS index for 384-dimensional
        embeddings. If ``embedder_model_name`` is changed, the index dimension
        has to be changed to match the new model's output size.
    """
    # Image captioning (BLIP): processor and model are loaded from the same
    # checkpoint name so that preprocessing and weights stay in sync.
    caption_processor = BlipProcessor.from_pretrained(caption_model_name)
    caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)

    # Text embedding (Sentence Transformers)
    embedder = SentenceTransformer(embedder_model_name)

    return caption_processor, caption_model, embedder

In [3]:
def generate_caption(image_path, processor, model):
    """Generate a text caption for the image stored at ``image_path``.

    Args:
        image_path: Path of the image file to caption.
        processor: Callable that turns an RGB image into model inputs
            (called with ``return_tensors="pt"``) and that also provides
            ``decode`` for turning generated token ids back into text.
        model: Object whose ``generate`` method produces token ids from the
            processor's inputs.

    Returns:
        The decoded caption string, or ``None`` if any step failed. Failures
        are reported with ``print`` rather than raised, so a single bad file
        does not abort a batch run.
    """
    try:
        # Image.open keeps the underlying file open; the context manager
        # releases the handle as soon as the RGB copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, max_length=50)  # cap caption length at 50 tokens
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberately best-effort: report the problem and signal "no caption".
        print(f"Error generating caption: {e}")
        return None

In [4]:
def encode_text(text, embedder):
    """Return whatever ``embedder.encode`` produces for ``text``.

    Args:
        text: The text to embed.
        embedder: Object exposing an ``encode`` method.

    Returns:
        The result of ``embedder.encode(text)``, unchanged.
    """
    embedding = embedder.encode(text)
    return embedding

In [5]:
def initialize_faiss(dimension=384):
    """Create an empty FAISS flat L2 (exact search) index.

    Args:
        dimension: Length of the vectors the index will hold. The default of
            384 is the embedding size this notebook uses for its Sentence
            Transformer output.

    Returns:
        A new ``faiss.IndexFlatL2`` built for ``dimension``-sized vectors.
    """
    return faiss.IndexFlatL2(dimension)

In [6]:
def store_in_faiss(index, image_id, caption, caption_vector, metadata_list):
    """Add one caption embedding to ``index`` and record its metadata.

    ``query_faiss`` looks results up with ``metadata_list[idx]``, so the entry
    appended here is expected to sit at the same position as the vector added
    to the index.

    Args:
        index: FAISS index (anything exposing ``add``).
        image_id: Identifier stored alongside the caption.
        caption: Caption text stored in the metadata.
        caption_vector: Embedding of the caption; a 1-D sequence/array of
            numbers (a ``(1, d)`` array is accepted as well).
        metadata_list: List that receives ``{"image_id", "caption"}`` dicts;
            mutated in place.

    Returns:
        None.
    """
    # FAISS operates on C-contiguous float32 matrices of shape (n, d). Coerce
    # explicitly so float64 arrays or plain Python lists are stored correctly.
    vector = np.ascontiguousarray(caption_vector, dtype=np.float32).reshape(1, -1)
    index.add(vector)

    # Append only after the add succeeded so index rows and metadata entries
    # cannot drift apart when add() raises.
    metadata_list.append({"image_id": image_id, "caption": caption})

In [7]:
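# NOTE: disabled earlier version of query_faiss, kept only for reference. Unlike the
# active definition below, it does not validate the indices returned by the search.
# def query_faiss(index, query_text, embedder, metadata_list, top_k=5):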
#     query_vector = np.array([embedder.encode(query_text)])  # Encode query
#     distances, indices = index.search(query_vector, top_k)  # Search for top-k matches
#     results = []
#     for i, idx in enumerate(indices[0]):
#         metadata = metadata_list[idx]
#         results.append({
#             "image_id": metadata["image_id"],
#             "caption": metadata["caption"],
#             "distance": distances[0][i]
#         })
#     return results

def query_faiss(index, query_text, embedder, metadata_list, top_k=5):
    """Return the stored captions closest to ``query_text``.

    Args:
        index: FAISS index exposing ``ntotal`` and ``search``.
        query_text: Free-text query to embed and search for.
        embedder: Object whose ``encode`` method turns text into a vector.
        metadata_list: Metadata dicts (``"image_id"``, ``"caption"``), looked
            up by the positions the index search returns.
        top_k: Maximum number of matches to return.

    Returns:
        A list of dicts with ``"image_id"``, ``"caption"`` and ``"distance"``
        keys, in the order the index returned them. Empty when the index holds
        no vectors or ``top_k`` is not positive.
    """
    # Never request more neighbours than the index contains: FAISS fills the
    # missing slots with index -1, which only produces warning noise below.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, d).
    query_vector = np.ascontiguousarray(
        embedder.encode(query_text), dtype=np.float32
    ).reshape(1, -1)
    distances, indices = index.search(query_vector, k)

    results = []
    for distance, idx in zip(distances[0], indices[0]):
        # Guard against positions with no metadata entry (e.g. the index and
        # metadata_list got out of sync).
        if 0 <= idx < len(metadata_list):
            metadata = metadata_list[idx]
            results.append({
                "image_id": metadata["image_id"],
                "caption": metadata["caption"],
                "distance": distance
            })
        else:
            print(f"Warning: Invalid index {idx} encountered. Skipping.")
    return results

In [8]:
def main(image_dir="./images", query="Show me images of dog", top_k=5):
    """Caption every image in a directory, index the captions, run one query.

    For each regular, non-hidden file in ``image_dir`` (visited in sorted
    order) a caption is generated, embedded, and stored in a FAISS index
    together with its metadata. The index is then searched for ``query`` and
    the matches are printed.

    Args:
        image_dir: Directory containing the images to index.
        query: Free-text query to search the captions with.
        top_k: Maximum number of matches to print.

    Returns:
        None. All results are reported via ``print``.
    """
    DIMENSION = 384  # Matches Sentence Transformer output

    # Fail fast, before the models are loaded, when there is nothing to read.
    if not os.path.isdir(image_dir):
        print(f"Image directory not found: {image_dir}")
        return

    # Initialize models
    caption_processor, caption_model, embedder = initialize_models()

    # Initialize FAISS index and metadata storage
    index = initialize_faiss(DIMENSION)
    metadata_list = []  # To store captions and IDs

    # Process images and store them in FAISS. os.listdir returns entries in
    # arbitrary order; sorting keeps index positions reproducible across runs.
    for image_file in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_file)

        # Skip sub-directories and hidden files (e.g. ".DS_Store"): they are
        # not images and would only produce captioning errors.
        if image_file.startswith(".") or not os.path.isfile(image_path):
            continue

        image_id = os.path.splitext(image_file)[0]

        # Generate caption; generate_caption reports its own errors and
        # returns None (or an empty string) when there is nothing to store.
        caption = generate_caption(image_path, caption_processor, caption_model)
        if not caption:
            continue
        print(f"Caption for {image_file}: {caption}")

        # Encode caption into vector
        caption_vector = encode_text(caption, embedder)

        # Store in FAISS
        store_in_faiss(index, image_id, caption, caption_vector, metadata_list)
        print(f"Stored {image_file} in FAISS.")

    # Query FAISS
    results = query_faiss(index, query, embedder, metadata_list, top_k=top_k)
    print("\nQuery Results:")
    for result in results:
        print(f"Image ID: {result['image_id']}, Caption: {result['caption']}, Distance: {result['distance']}")

# Entry point: run the full pipeline when this code is executed as the
# top-level module (which is the case when the cell is run in a notebook).
if __name__ == "__main__":
    main()



Error generating caption: cannot identify image file '/Users/Nisarg/Downloads/MindPalace/images/dog.jpg'

Query Results:
