In [1]:
import faiss
import numpy as np
from PIL import Image
from clip import clip

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchvision.datasets import CIFAR10
cifar10 = CIFAR10(root='../dataset', train=True, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 14988135.83it/s]


Extracting ../dataset/cifar-10-python.tar.gz to ../dataset


In [3]:
import torch
import torchvision.transforms as transforms
import torchvision.models as models


In [4]:
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
def get_embedding(image):
    image_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        embedding = resnet(image_tensor).flatten()
    return embedding.numpy()

In [8]:
embeddings = [get_embedding(image[0]) for image in cifar10]

In [15]:
dim = 1000
index = faiss.IndexFlatL2(dim)

index.add(np.array(embeddings))

faiss.write_index(index, "cifar10_index.faiss")

In [17]:
query_image_path = "elephant.jpeg"
query_image = Image.open(query_image_path)
query_embedding = get_embedding(query_image)
distances, indices = index.search(np.array([query_embedding]), k=10)

In [18]:
print("Distances:")
print(distances)
print("\nIndices:")
print(indices)

Distances:
[[235.0747  328.4562  334.23212 344.09167 356.61414 366.49725 373.54117
  381.11877 411.14352 414.9278 ]]

Indices:
[[ 1406 43864   387 31382 43390  1425 11634  2099 36712 37097]]
