In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision import transforms
from PIL import Image
import numpy as np
import os
import json
import faiss
from typing import List, Tuple, Optional
from datetime import datetime
import math

import logging


# Set up logging to track what the system is doing:
# configure the root logger at INFO level and create the module-level
# `logger` object used by the classes and functions in this notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

INFO:faiss.loader:Loading faiss with AVX2 support.
INFO:faiss.loader:Successfully loaded faiss with AVX2 support.


# Feature extractor

In [4]:
class ImageDataset(Dataset):
    """
    Dataset for batch processing of images.

    Wraps a list of image paths so they can be consumed through PyTorch's
    DataLoader. Images are opened lazily, one at a time, when indexed.
    """
    def __init__(self, image_paths: list, transform=None):
        """
        Initialize the dataset.

        Args:
            image_paths: List of file paths to images
            transform: Optional callable applied to each loaded RGB image
        """
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load and return a single image at the given index.

        Args:
            idx: Index of the image to load

        Returns:
            Tuple of (transformed_image, image_path). When no transform was
            supplied, the first element is the RGB PIL image itself.

        Raises:
            IndexError: If idx is out of range for the path list.
            Exception: Any error raised while opening or transforming the
                image is logged and then re-raised unchanged.
        """
        image_path = self.image_paths[idx]
        try:
            # The context manager releases the underlying file handle as soon
            # as the pixel data has been copied by convert(); without it the
            # file of a multi-frame image (e.g. animated WebP) stays open
            # until the image object is garbage-collected.
            with Image.open(image_path) as source:
                # Convert to RGB (in case it's grayscale or RGBA)
                image = source.convert('RGB')

            # Apply transformations if provided (e.g., resizing, normalization)
            if self.transform is not None:
                image = self.transform(image)
            return image, image_path
        except Exception as e:
            logger.error(f"Error loading image {image_path}: {str(e)}")
            raise

In [5]:
class ImageFeatureExtractor:
    """
    Extracts visual features from images using a pre-trained Vision Transformer (ViT) model.
    These features are numerical representations that capture the visual content of images.
    """
    def __init__(self, device: Optional[str] = None):
        """
        Initialize the feature extractor with a ViT model.

        Args:
            device: Device to run model on ('cuda' or 'cpu'). Auto-detects if None.
        """
        # Automatically choose GPU if available, otherwise use CPU
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        logger.info(f"Using device: {self.device}")

        # Load pre-trained Vision Transformer model (ImageNet-1k weights)
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

        # Save the original forward method and replace it with our custom one,
        # so that calling self.model(x) yields embeddings instead of class logits.
        self.original_forward = self.model.forward
        self.model.forward = self._forward_features

        # Set model to evaluation mode (disables dropout, etc.)
        self.model.eval()

        # Move model to the appropriate device (GPU or CPU)
        self.model.to(self.device)

        # Size of the embedding checked in extract_features()
        self.feature_dim = 768

        # Preprocessing pipeline applied to every image before the model sees it
        self.transform = transforms.Compose([
            transforms.Resize(224),  # Resize shorter side to 224 pixels
            transforms.CenterCrop(224),  # Crop center 224x224 region
            transforms.ToTensor(),  # Convert to PyTorch tensor (scales to [0,1])
            # Normalize using ImageNet statistics (required for pre-trained models)
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

        logger.info(f"Initialized ViT feature extractor with dimension: {self.feature_dim}")

    def _forward_features(self, x):
        """
        Modified forward pass that returns feature embeddings instead of class predictions.

        Steps:
        1. The model's input processing turns the image batch into a token sequence
        2. A class token is prepended to every sequence
        3. The sequence is passed through the encoder
        4. The output at the class-token position is returned

        Args:
            x: Input image tensor (batched)

        Returns:
            Tensor holding the class-token embedding for each item in the batch
        """
        # Process input through the model's initial layers
        x = self.model._process_input(x)
        n = x.shape[0]  # Batch size

        # Prepend the class token to every sequence in the batch
        cls_token = self.model.class_token.expand(n, -1, -1)
        x = torch.cat([cls_token, x], dim=1)

        # Pass through transformer encoder layers
        x = self.model.encoder(x)

        # Keep only the first token (the class token) of each sequence
        return x[:, 0]

    @torch.no_grad()  # Disable gradient computation for faster inference
    def extract_features(self, image_path: str) -> np.ndarray:
        """
        Extract a feature vector from a single image.

        Args:
            image_path: Path to the image file

        Returns:
            L2-normalized feature vector of shape (feature_dim,) as a numpy array

        Raises:
            ValueError: If the model output does not have shape (feature_dim,)
            Exception: Any error while reading or processing the image is
                logged and re-raised unchanged.
        """
        try:
            # Open inside a context manager so the file handle is released as
            # soon as convert() has copied the pixel data; multi-frame files
            # (e.g. animated WebP) otherwise stay open until garbage collection.
            with Image.open(image_path) as source:
                rgb_image = source.convert('RGB')

            # Preprocess and add the batch dimension expected by the model
            image = self.transform(rgb_image).unsqueeze(0).to(self.device)

            # self.model(...) runs _forward_features (patched in __init__)
            features = self.model(image)

            # Convert to numpy and drop only the batch axis
            features = features.cpu().numpy().squeeze(0)

            # Verify we got the expected feature dimension
            if features.shape != (self.feature_dim,):
                raise ValueError(f"Unexpected feature dimension: {features.shape}")

            # L2 normalization: make the feature vector have unit length.
            # A zero vector is returned as-is to avoid dividing by zero.
            norm = np.linalg.norm(features)
            if norm > 0:
                features = features / norm

            # Only pay for the extra norm computation when DEBUG is enabled
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Extracted features shape: {features.shape}")
                logger.debug(f"Features norm: {np.linalg.norm(features)}")

            return features

        except Exception as e:
            logger.error(f"Error extracting features from {image_path}: {str(e)}")
            raise

    def __del__(self):
        """
        Cleanup method called when the object is destroyed.
        Restores the model's original forward method to avoid side effects.
        """
        # Both attributes may be missing if __init__ failed part-way through
        if hasattr(self, 'original_forward') and hasattr(self, 'model'):
            self.model.forward = self.original_forward

# Retrieval system

In [6]:
# Image Retrieval System

# This system finds similar images using AI-powered visual search. Think of it like
# a reverse image search, but for your own photo collection.

# How it works:

# 1. Feature Extraction: 
#    - Converts each image into a "fingerprint" (768 numbers that describe what's in the image)
#    - Similar images have similar fingerprints

# 2. Indexing: 
#    - Processes all your images and stores their fingerprints in a searchable database
#    - Uses FAISS (Facebook AI Similarity Search) for efficient searching
#    - Saves metadata about each image (filename, path, when it was indexed)

# 3. Search: 
#    - Takes a query image and finds images with similar fingerprints
#    - Returns the most similar images ranked by similarity
#    - Uses IndexIVFFlat for fast searching in large collections

# About IndexIVFFlat (inverted-file index that stores the full, uncompressed vectors):
#     - Organizes images into clusters/regions (like organizing books by genre)
#     - When searching:
#       * First finds which clusters are most relevant to your query
#       * Only searches within those clusters (much faster!)
#     - Two-step setup:
#       * Training: Learns how to group images into clusters
#       * nprobe: How many clusters to search (more = slower but more accurate)
#     - Trade-off: Speed vs accuracy (might miss some matches, but usually finds good ones)
#     - Best for large image collections (1000+ images)

In [8]:
class ImageRetrievalSystem:
    """
    Main system for indexing and searching images based on visual similarity.

    Feature vectors are stored in a FAISS index; `self.metadata` maps the
    FAISS vector id (as a string, so it survives the JSON round-trip) to the
    image's path, filename and indexing timestamp.
    """

    # File extensions (lower-case) that index_images() picks up
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self,
                 feature_extractor: Optional["ImageFeatureExtractor"] = None,
                 index_path: Optional[str] = None,
                 metadata_path: Optional[str] = None,
                 use_gpu: bool = False,
                 n_regions: int = 100,  # Number of clusters to divide images into
                 nprobe: int = 10):     # Number of clusters to search
        """
        Initialize the image retrieval system.

        Args:
            feature_extractor: Object to extract features from images
            index_path: Path to load existing FAISS index from
            metadata_path: Path to load existing metadata from
            use_gpu: Whether to use GPU acceleration for FAISS
            n_regions: Number of clusters/regions for IVF index
            nprobe: Number of regions to search (higher = more accurate but slower)
        """
        # Use provided feature extractor or create a new one
        self.feature_extractor = feature_extractor or ImageFeatureExtractor()
        self.feature_dim = self.feature_extractor.feature_dim
        self.n_regions = n_regions
        self.nprobe = nprobe
        logger.info(f"Initializing retrieval system with dimension: {self.feature_dim}")

        # Dictionary to store information about each indexed image
        self.metadata = {}

        # Flag to track if the index has been trained
        self.is_trained = False

        # Holds the FAISS GPU resources object (if any) for the index lifetime
        self._gpu_resources = None

        # Load existing index if paths are provided
        if index_path and metadata_path:
            self.load(index_path, metadata_path)
        else:
            if index_path or metadata_path:
                # Loading needs both files; make the fallback visible
                logger.warning("Both index_path and metadata_path are required to load "
                               "an existing index; creating a new index instead")

            # Create a new FAISS index from scratch
            logger.info(f"Creating new IVF index with {n_regions} regions")

            # Quantizer: flat L2 index used by the IVF index to assign vectors to regions
            self.quantizer = faiss.IndexFlatL2(self.feature_dim)

            # Main index: uses IVF (Inverted File) for faster search
            # METRIC_L2 means we use Euclidean distance to measure similarity
            self.index = faiss.IndexIVFFlat(self.quantizer, self.feature_dim,
                                          self.n_regions, faiss.METRIC_L2)

            # Set how many regions to search during queries
            self.index.nprobe = self.nprobe

        # Honour use_gpu for both freshly created and loaded indexes
        if use_gpu:
            self._move_index_to_gpu()

    def _move_index_to_gpu(self) -> None:
        """Try to move self.index to GPU 0; on any failure keep the CPU index and log a warning."""
        try:
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
            # Keep a reference so the resources live as long as the index does
            self._gpu_resources = res
            logger.info("Successfully moved index to GPU")
        except Exception as e:
            logger.warning(f"Failed to use GPU, falling back to CPU: {str(e)}")

    def index_images(self,
                    image_dir: str,
                    batch_size: int = 32,
                    num_workers: int = 4) -> None:
        """
        Process and index all images in a directory.

        May be called more than once (or after load()): new vectors are
        appended and their metadata is keyed by the id FAISS assigns them, so
        earlier entries are never overwritten.

        Args:
            image_dir: Directory containing images to index
            batch_size: Number of images to process at once (unused in current implementation)
            num_workers: Number of parallel workers (unused in current implementation)

        Raises:
            ValueError: If no image could be processed, or if the index still
                needs training and fewer vectors than n_regions are available.
        """
        logger.info(f"Indexing images from {image_dir}")

        # Find all image files in the directory; sorted so that the id -> image
        # assignment is reproducible (os.listdir order is arbitrary)
        image_paths = sorted(
            os.path.join(image_dir, f) for f in os.listdir(image_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        # Lists to store features and their corresponding paths
        features_list = []
        valid_paths = []

        # Process each image one by one
        for img_path in image_paths:
            try:
                features = self.feature_extractor.extract_features(img_path)
            except Exception as e:
                # Skip images that can't be processed
                logger.error(f"Error processing {img_path}: {str(e)}")
                continue
            features_list.append(features)
            valid_paths.append(img_path)
            logger.info(f"Processed {img_path}")

        # Make sure we successfully processed at least some images
        if not features_list:
            raise ValueError("No valid features extracted from images")

        # Stack into one (num_images, feature_dim) array; FAISS requires
        # C-contiguous float32 input
        all_features = np.ascontiguousarray(np.stack(features_list), dtype=np.float32)
        logger.info(f"Feature array shape: {all_features.shape}")
        logger.info(f"Feature stats - Min: {all_features.min():.4f}, Max: {all_features.max():.4f}")

        # Train the IVF index to learn how to cluster the features
        # This only needs to be done once when first creating the index
        if not self.is_trained:
            if len(all_features) < self.n_regions:
                # Clustering cannot produce more regions than training vectors;
                # fail with an actionable message instead of a FAISS internal error
                raise ValueError(
                    f"Cannot train an index with {self.n_regions} regions on only "
                    f"{len(all_features)} images; use a smaller n_regions"
                )
            logger.info("Training IVF index...")
            self.index.train(all_features)
            self.is_trained = True
            logger.info("Index training completed")

        # FAISS assigns sequential ids starting at the current ntotal, so
        # remember it before adding to key the metadata correctly
        first_id = int(self.index.ntotal)

        # Add all feature vectors to the searchable index
        self.index.add(all_features)
        logger.info(f"Total vectors in index: {self.index.ntotal}")

        # Store metadata for each indexed image
        # This allows us to retrieve the original image path later
        for offset, path in enumerate(valid_paths):
            self.metadata[str(first_id + offset)] = {
                'path': path,
                'filename': os.path.basename(path),
                'indexed_at': datetime.now().isoformat()
            }

        logger.info(f"Successfully indexed {len(valid_paths)} images")

    def search(self,
              query_image_path: str,
              k: int = 5) -> List[Tuple[str, float]]:
        """
        Search for images similar to a query image.

        Args:
            query_image_path: Path to the image to search for
            k: Number of similar images to return

        Returns:
            List of tuples: (image_path, distance_score), sorted by ascending
            distance (lower distance = more similar). May contain fewer than
            k entries, and is empty when the index holds no vectors.

        Raises:
            RuntimeError: If the index has not been trained yet.
        """
        logger.info(f"Searching for similar images to {query_image_path}")
        logger.info(f"Total images in index: {self.index.ntotal}")
        # The full key list can be huge; keep it out of INFO-level output
        logger.debug(f"Available metadata keys: {list(self.metadata.keys())}")

        # Make sure the index has been trained before searching
        if not self.is_trained:
            raise RuntimeError("Index has not been trained. Add images first.")

        # Nothing to search: avoid asking FAISS for k == 0 neighbours
        if self.index.ntotal == 0:
            logger.warning("No matches found!")
            return []

        # Extract features from the query image
        query_features = self.feature_extractor.extract_features(query_image_path)
        logger.info(f"Query feature shape: {query_features.shape}")

        # Don't try to return more results than we have images
        k = min(k, self.index.ntotal)

        # FAISS expects a C-contiguous float32 batch of shape (1, feature_dim)
        query = np.ascontiguousarray(query_features.reshape(1, -1), dtype=np.float32)

        # Returns: distances (how different) and indices (which images)
        distances, indices = self.index.search(query, k)

        logger.info(f"Raw search results - distances: {distances[0]}")
        logger.info(f"Raw search results - indices: {indices[0]}")
        logger.info(f"Searched {self.nprobe} out of {self.n_regions} regions")

        # Convert indices to image paths using metadata
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # FAISS pads with id -1 when the probed regions hold fewer
                # than k vectors; these are not real matches
                continue
            str_idx = str(int(idx))
            if str_idx in self.metadata:
                # Store the image path and its distance score
                results.append((self.metadata[str_idx]['path'], float(dist)))
                logger.info(f"Match found: {self.metadata[str_idx]['path']} with distance {dist:.3f}")
            else:
                logger.warning(f"Index {idx} not found in metadata")

        # Sort by distance (smaller distance = more similar)
        results.sort(key=lambda x: x[1])

        if not results:
            logger.warning("No matches found!")
        else:
            logger.info(f"Found {len(results)} matches")

        return results

    def save(self, index_path: str, metadata_path: str) -> None:
        """
        Save the index and metadata to disk for later use.

        The in-memory index is left untouched (a GPU index stays on the GPU);
        only the copy written to disk is converted to CPU form.

        Args:
            index_path: Where to save the FAISS index
            metadata_path: Where to save the metadata JSON
        """
        # When GPUs are present, write a CPU copy instead of replacing self.index
        index_to_write = self.index
        if faiss.get_num_gpus() > 0:
            index_to_write = faiss.index_gpu_to_cpu(self.index)

        # Save FAISS index (the searchable database)
        faiss.write_index(index_to_write, index_path)

        # Save metadata (image paths and info) as JSON
        with open(metadata_path, 'w') as f:
            json.dump(self.metadata, f)

        logger.info(f"Saved index with {self.index.ntotal} vectors")
        logger.info(f"Saved index to {index_path} and metadata to {metadata_path}")

    def load(self, index_path: str, metadata_path: str) -> None:
        """
        Load a previously saved index and metadata from disk.

        The current index and metadata are only replaced once both files have
        been read successfully.

        Args:
            index_path: Path to the saved FAISS index
            metadata_path: Path to the saved metadata JSON

        Raises:
            ValueError: If the stored index dimension differs from the
                feature extractor's dimension.
        """
        logger.info(f"Loading index from {index_path}")

        # Load the FAISS index
        loaded_index = faiss.read_index(index_path)

        # A mismatching index would only fail later, at search time
        if loaded_index.d != self.feature_dim:
            raise ValueError(
                f"Index dimension {loaded_index.d} does not match "
                f"feature extractor dimension {self.feature_dim}"
            )

        # Load the metadata
        with open(metadata_path, 'r') as f:
            loaded_metadata = json.load(f)

        self.index = loaded_index
        self.metadata = loaded_metadata

        # Trust the index's own training state rather than assuming it
        self.is_trained = bool(self.index.is_trained)

        # Set nprobe for the loaded index (how many regions to search) and
        # pick up the region count the index was actually built with
        if isinstance(self.index, faiss.IndexIVF):
            self.n_regions = int(self.index.nlist)
            self.index.nprobe = self.nprobe
            logger.info(f"Set nprobe to {self.nprobe} for loaded IVF index")

        logger.info(f"Loaded index with {self.index.ntotal} vectors")
        logger.info(f"Metadata contains {len(self.metadata)} entries")
        if self.index.ntotal != len(self.metadata):
            logger.warning("Index and metadata sizes differ; some results may have no metadata")

# Index and retrieve

In [9]:
# Main Application Runner

# This section provides a simple interface to use the image retrieval system.
# It handles two main tasks:

# 1. INDEXING (One-time setup):
#    - Processes all your images
#    - Creates a searchable database
#    - Saves it for future use
   
# 2. SEARCHING (Repeated use):
#    - Loads the saved database
#    - Finds images similar to your query
#    - Shows results with similarity scores

# Key Features:
# - Automatic optimization based on your collection size
# - Smart error handling
# - Flexible configuration options
# - User-friendly result display

In [10]:
# Fix for OpenMP warning on some systems
# NOTE(review): torch and faiss are imported in the first cell, i.e. before this
# variable is set - confirm it still takes effect at this point.
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT
)
# basicConfig() does nothing when the root logger already has handlers, which is
# the case here because the imports cell has already called it. Apply the level
# and format explicitly so the timestamped format is actually used.
logging.getLogger().setLevel(logging.INFO)
_log_formatter = logging.Formatter(_LOG_FORMAT)
for _handler in logging.getLogger().handlers:
    _handler.setFormatter(_log_formatter)

logger = logging.getLogger(__name__)

In [11]:
def calculate_optimal_regions(num_images: int) -> int:
    """
    Pick the number of IVF clusters/regions for a collection of a given size.

    Heuristic:
    - Fewer than 100 images: floor(sqrt(N)) regions, but never less than 1
    - 100 images or more: floor(4 * sqrt(N)) regions, capped at N // 2

    Args:
        num_images: Total number of images to index

    Returns:
        Number of regions to use for the index
    """
    root = math.sqrt(num_images)

    # Larger datasets: more regions, but at most half the dataset size
    if num_images >= 100:
        scaled = int(4 * root)
        cap = num_images // 2
        return min(scaled, cap)

    # Small datasets: fewer regions, at least one
    return max(1, int(root))

In [12]:
def print_results(results):
    """
    Print search results as a ranked, human-readable list.

    For every hit the output contains its rank, the file name, the full path,
    a similarity score computed as 1 / (1 + distance) (higher = more similar)
    and the raw distance (lower = more similar).

    Args:
        results: List of (image_path, distance) tuples
    """
    # Guard clause: nothing to show
    if not results:
        print("\nNo matches found!")
        return

    rule = "-" * 50

    print("\nSearch Results:")
    print(rule)
    for rank, hit in enumerate(results, start=1):
        image_path, dist = hit

        # Map distance onto a (0, 1] score where larger means closer
        score = 1.0 / (1.0 + dist)
        name = os.path.basename(image_path)

        print(f"{rank}. Image: {name}")
        print(f"   Full path: {image_path}")
        print(f"   Similarity Score: {score:.3f}")
        print(f"   Distance: {dist:.3f}")
        print(rule)

In [13]:
def run_image_retrieval(
    task: str = "index",                    # "index" or "search"
    image_dir: Optional[str] = None,        # Directory of images to index
    query_image: Optional[str] = None,      # Image to search for
    index_path: str = "image_index.faiss",
    metadata_path: str = "image_metadata.json",
    num_results: int = 5,                   # How many similar images to return
    n_regions: Optional[int] = None,        # Number of clusters (auto-calculated if None)
    nprobe: Optional[int] = None,           # Number of clusters to search (auto-calculated if None)
    use_gpu: bool = False                   # Use GPU acceleration
) -> None:
    """
    Main function to run the image retrieval system.

    INDEXING MODE (task="index"):
        - Counts the supported images in image_dir and fails fast if there are none
        - Chooses n_regions / nprobe from the collection size unless given
        - Builds the index and saves it to index_path / metadata_path

    SEARCH MODE (task="search"):
        - Loads the index from index_path / metadata_path
        - Finds images similar to query_image
        - Prints the results with similarity scores

    Args:
        task: Either "index" or "search" (case-insensitive)
        image_dir: Directory containing images to index (required for indexing)
        query_image: Path to query image (required for searching)
        index_path: Where to save/load the FAISS index
        metadata_path: Where to save/load the metadata
        num_results: Number of similar images to return
        n_regions: Number of clusters for IVF (auto-calculated if None); a value
            larger than the number of images is reduced to that number
        nprobe: Number of clusters to search. When indexing, None means about a
            quarter of the regions and larger values are capped at n_regions;
            when searching, None means 10.
        use_gpu: Whether to use GPU acceleration

    Raises:
        ValueError: On an unknown task, a missing required argument, an image
            directory without supported images, or missing index files.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        mode = task.lower()

        if mode == 'index':
            # ===== INDEXING MODE =====

            if not image_dir:
                raise ValueError("image_dir is required for indexing task")

            # Count how many images we have
            image_files = [f for f in os.listdir(image_dir)
                         if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]
            num_images = len(image_files)

            # Fail before the (expensive) model is created when there is nothing to index
            if num_images == 0:
                raise ValueError(f"No supported image files found in {image_dir}")

            # Calculate optimal settings if not provided
            if n_regions is None:
                n_regions = calculate_optimal_regions(num_images)
            elif n_regions > num_images:
                # The index cannot be trained with more regions than vectors
                logger.warning(f"n_regions={n_regions} exceeds the number of images "
                               f"({num_images}); using {num_images} regions instead")
                n_regions = num_images

            if nprobe is None:
                # Search about 1/4 of the regions by default
                nprobe = max(1, n_regions // 4)
            elif nprobe > n_regions:
                # Probing more regions than exist gains nothing
                nprobe = n_regions

            logger.info(f"Number of images: {num_images}")
            logger.info(f"Using {n_regions} regions and searching {nprobe} regions")

            # Create the retrieval system
            retrieval_system = ImageRetrievalSystem(
                n_regions=n_regions,
                nprobe=nprobe,
                use_gpu=use_gpu
            )

            # Process all images and build the index
            logger.info(f"Indexing images from {image_dir}")
            retrieval_system.index_images(image_dir=image_dir)

            # Save for later use
            logger.info("Saving index and metadata")
            retrieval_system.save(index_path, metadata_path)

        elif mode == 'search':
            # ===== SEARCH MODE =====

            if not query_image:
                raise ValueError("query_image is required for search task")

            # Make sure the index exists
            if not os.path.exists(index_path) or not os.path.exists(metadata_path):
                raise ValueError(f"Index or metadata file not found. Please ensure both exist:\n"
                               f"Index: {index_path}\nMetadata: {metadata_path}")

            # Load the existing index
            logger.info(f"Loading existing index for search")
            retrieval_system = ImageRetrievalSystem(
                index_path=index_path,
                metadata_path=metadata_path,
                nprobe=nprobe if nprobe is not None else 10,
                use_gpu=use_gpu
            )

            # Search for similar images
            logger.info(f"Searching for similar images to {query_image}")
            results = retrieval_system.search(
                query_image_path=query_image,
                k=num_results
            )

            # Display the results
            print_results(results)

        else:
            raise ValueError("Task must be either 'index' or 'search'")

    except Exception as e:
        # Log every failure once, then let the caller decide how to handle it
        logger.error(f"Error: {str(e)}")
        raise