# Content-Based Image Retrieval

This notebook is an example of a simple content-based image retrieval (CBIR) system. A pretrained **Vision Transformer (ViT)** extracts image features, and the feature vector of the query image is compared with the feature vector of each image in the database. The top 5 images ranked by cosine similarity are presented to the user.

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from timm import create_model
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.utils import make_grid
from PIL import Image
from IPython.display import display

## Dataset
This is the initial image set that makes up the database we query. A function that loads an image from a given path and returns it as a tensor is declared. Images are resized to 224 by 224 pixels with 3 channels and normalized. This transformation puts each image in the input format the model expects.

In [None]:
DATASET_DIR = r"D:\Projs\datasets\small\monarch_150"

In [None]:
def path_to_tensor(img_path, img_transforms):
    image = Image.open(img_path).convert("RGB")
    tensor_image = img_transforms(image)
    return tensor_image

In [None]:
class ImgFolderDataset(Dataset):
    def __init__(self, img_dir, transforms):
        self.img_dir = img_dir
        self.transforms = transforms
        # Sort so that the file order is the same on every run
        self.img_filenames = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.img_filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_filenames[idx])
        tensor_image = path_to_tensor(img_path, self.transforms)
        return tensor_image, img_path

In [None]:
data_transforms = transforms.Compose([
    # Resize to exactly 224x224; Resize(224) would only scale the shorter side
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = ImgFolderDataset(DATASET_DIR, data_transforms)
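
To make the normalization step concrete: `ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` with a mean and standard deviation of 0.5 then maps each channel value x to (x - 0.5) / 0.5, which gives the range [-1, 1]. The sketch below reproduces that arithmetic in plain Python. `normalize_value` is a hypothetical helper used only for illustration; it is not part of torchvision.

```python
def normalize_value(pixel, mean=0.5, std=0.5):
    """Mimic ToTensor followed by Normalize for a single 8-bit pixel value."""
    scaled = pixel / 255.0        # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std  # Normalize: (x - mean) / std

# Darkest and brightest pixel values land on the ends of [-1, 1]
for pixel in (0, 255):
    print(pixel, normalize_value(pixel))
```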

## Model
Images are mapped to feature vectors with **ViT-Base (ViT-B/32)**, described in [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). The pretrained model is fetched through the `timm` library. One slight modification is made to the original model: the final classification head, which produces the logits that would be fed to a softmax, is removed, as described in [Investigating the Vision Transformer Model for Image Retrieval Tasks](https://arxiv.org/abs/2101.03771).

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = create_model("vit_base_patch32_224", pretrained=True)
# Replace the classification head with an identity layer; output is of size (1, 768)
model.head = nn.Identity()
model = model.to(device)

## Database
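The image database is a dictionary that maps feature vectors to file paths. The features of every image in the dataset above are computed and inserted into it. Because PyTorch tensors hash by object identity, each feature tensor acts as a unique key.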

In [None]:
img_dict = {}

In [None]:
with torch.no_grad():
    model.eval()
    for img, img_path in dataset:
        img = img[None,:].to(device)
        img_feats = model(img)
        img_dict[img_feats] = img_path

## Similarity criterion
Image similarity is calculated using cosine similarity between two feature vectors, in this case the query image feature vector and each individual feature vector from the database.

In [None]:
sim_fun = nn.CosineSimilarity(dim=1)
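
As a worked illustration of the criterion, cosine similarity is the dot product of two vectors divided by the product of their lengths, so it depends only on the angle between them and not on their magnitudes. The plain-Python sketch below spells out that formula on toy vectors. `cosine_similarity` is a hypothetical helper for illustration; `nn.CosineSimilarity` computes the same quantity on tensors, apart from a small epsilon it uses for numerical stability.

```python
import math

def cosine_similarity(a, b):
    """Dot product of a and b divided by the product of their Euclidean norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

v = [1.0, 2.0, 3.0]
print(cosine_similarity(v, [2.0, 4.0, 6.0]))     # same direction: close to 1
print(cosine_similarity(v, [-1.0, -2.0, -3.0]))  # opposite direction: close to -1
print(cosine_similarity([1.0, 0.0], [0.0, 1.0])) # orthogonal vectors: 0
```

Values near 1 therefore indicate feature vectors pointing in nearly the same direction, which is why the results are sorted in descending order.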

## Querying
The 5 database images most similar to the query are retrieved and arranged in a single grid image.

In [None]:
def find_top5(img_path, img_dict, sim_fun, model):
    img = path_to_tensor(img_path, data_transforms)
    with torch.no_grad():
        model.eval()
        img = img[None,:].to(device)
        feats = model(img)
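    # Score every database entry; .item() turns the 1-element tensor into a float
    comparisons = [(sim_fun(feats, img2_feats).item(), img2_path) for img2_feats, img2_path in img_dict.items()]
    top5 = sorted(comparisons, key=lambda x: x[0], reverse=True)[:5]
    # make_grid needs equally sized 3-channel tensors, so convert and resize first
    top5 = [Image.open(x[1]).convert("RGB").resize((224, 224)) for x in top5]
    return to_pil_image(make_grid([to_tensor(i) for i in top5], nrow=5, padding=10))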
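
For larger databases, the per-entry loop can be replaced by a single matrix operation. The following is a minimal sketch of that alternative, not the approach used in this notebook: it assumes the database features have been stacked into one `(N, D)` tensor, and `top_k_similar` is a hypothetical helper name. Random tensors stand in for real ViT features so the block runs on its own.

```python
import torch
import torch.nn.functional as F

def top_k_similar(query_feats, db_feats, k=5):
    """Return (scores, indices) of the k database rows most similar to the query.

    query_feats has shape (1, D) and db_feats has shape (N, D).
    """
    # After L2 normalization, a dot product equals the cosine similarity
    q = F.normalize(query_feats, dim=1)
    db = F.normalize(db_feats, dim=1)
    scores = (db @ q.T).squeeze(1)  # shape (N,)
    k = min(k, scores.numel())      # guard against databases smaller than k
    return torch.topk(scores, k)    # values are sorted in descending order

# Toy stand-in for real features: the query is a noisy copy of row 7
torch.manual_seed(0)
db_feats = torch.randn(20, 768)
query_feats = db_feats[7:8] + 0.01 * torch.randn(1, 768)
scores, indices = top_k_similar(query_feats, db_feats)
print(indices.tolist())
print(scores.tolist())
```

In this notebook, one way to build such a matrix would be `torch.cat(list(img_dict.keys()))`, with `list(img_dict.values())` giving the file paths in the same order; the returned indices then select the matching paths.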

In [None]:
img_path = r"D:\Projs\datasets\small\monarch_651\0000_000009.png"
query = Image.open(img_path)
top5 = find_top5(img_path, img_dict, sim_fun, model)

In [None]:
display(query, top5)