In [1]:
import csv
import pickle

import faiss
import torch
import torchvision.transforms as transforms
from PIL import Image

# Preprocessing pipeline applied to every image before it reaches the network:
# resize to 256, take the central 224x224 crop, convert to a tensor, then
# normalize each channel with the mean/std values conventionally used for
# ImageNet-pretrained torchvision models.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

# Pre-trained ResNet-50, fetched through torch.hub from the torchvision
# repository pinned at tag v0.10.0.
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'resnet50',
    pretrained=True,
)
# Switch to inference mode. As the last expression of the cell this also
# echoes the model architecture.
model.eval()

Using cache found in /home/hamza-ubuntu/.cache/torch/hub/pytorch_vision_v0.10.0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [2]:
def extract_features(image_path):
    """
    Extracts features (embeddings) from an image.

    The image is converted to RGB before preprocessing, so grayscale,
    palette and RGBA files produce the 3-channel tensor that the
    normalization step of ``transform`` requires instead of being rejected.

    Args:
        image_path (str): Path to the image file.

    Returns:
        np.ndarray: The output of ``model`` for this single image with the
                    batch dimension removed, or None if there's an error
                    (the error is printed, not raised).
    """
    try:
        # The context manager guarantees the underlying file handle is
        # released even if preprocessing fails.
        with Image.open(image_path) as img:
            # convert("RGB") forces exactly three channels regardless of the
            # file's native mode; unsqueeze(0) adds the batch dimension.
            img_tensor = transform(img.convert("RGB")).unsqueeze(0)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            # squeeze(0) drops only the batch dimension added above.
            features = model(img_tensor).squeeze(0).numpy()
        return features
    except Exception as e:
        # Deliberately best-effort: report the problem and let the caller
        # decide what to do with a missing embedding.
        print(f"Error processing image {image_path}: {e}")
        return None

In [3]:
def search_similar_images(query_image_path, index, image_paths, k=5):
    """
    Searches for similar images based on a query image.

    Args:
        query_image_path (str): Path to the query image.
        index (faiss.Index): The Faiss index containing image embeddings.
        image_paths (list): List of image paths; position ``i`` is expected
            to correspond to vector ``i`` in ``index``.
        k (int): Maximum number of neighbours to return. Defaults to 5.

    Returns:
        list: Paths to the most similar images, in the order returned by
              the index search. Empty if the query image could not be
              embedded; shorter than ``k`` if the index holds fewer than
              ``k`` vectors.
    """
    query_embedding = extract_features(query_image_path)
    if query_embedding is None:
        # extract_features has already printed the reason; return an empty
        # result instead of failing on None.reshape below.
        return []

    # Faiss expects a 2-D (n_queries, dim) array; there is a single query.
    _, indices = index.search(query_embedding.reshape(1, -1), k)

    # Faiss pads the result with -1 when fewer than k neighbours exist. A
    # negative value would otherwise silently select a path from the END of
    # image_paths, so those slots are dropped.
    return [image_paths[i] for i in indices[0] if i >= 0]

In [None]:
def load_embeddings(filename):
  """
  Loads embeddings from a pickle file.

  Args:
      filename (str): The filename to load the embeddings from.

  Returns:
      The object stored in the pickle file (expected to be the array of
      image embeddings).
  """
  # NOTE: pickle.load can execute arbitrary code embedded in the file; only
  # use this on embedding files you created yourself or otherwise trust.
  with open(filename, 'rb') as f:
    return pickle.load(f)

In [9]:
# Exact L2-distance index whose dimensionality matches the size of the
# model's final layer output, i.e. the vectors extract_features produces.
index = faiss.IndexFlatL2(model.fc.out_features)
embeddings = load_embeddings('mega_embeddings.pkl')
# NOTE(review): assumes the pickled embeddings are a 2-D float32 array of
# that dimensionality, as Faiss requires - confirm against the notebook
# that produced the file.
index.add(embeddings)

# newline='' is the mode the csv module documents for files handed to
# csv.reader, so embedded newlines in quoted fields are parsed correctly.
with open('image_paths.csv', 'r', newline='') as csvfile:
    reader = csv.reader(csvfile)
    # The first column holds the image path; blank lines yield empty rows
    # and are skipped rather than raising IndexError.
    image_paths_loaded = [row[0] for row in reader if row]

# Vector i in the index must correspond to image_paths_loaded[i]; a length
# mismatch would make every search result point at the wrong file.
assert index.ntotal == len(image_paths_loaded), (
    f"index holds {index.ntotal} vectors but image_paths.csv lists "
    f"{len(image_paths_loaded)} paths"
)

query_image_path = "test_image/vase.jpeg"
similar_images = search_similar_images(query_image_path, index, image_paths_loaded)
print(similar_images)

['SemArt/Images/08686-10vase.jpg', 'SemArt/Images/34434-stillif.jpg', 'SemArt/Images/44771-still_li.jpg', 'SemArt/Images/31879-y_woman1.jpg', 'SemArt/Images/34462-18rippl.jpg']
