<a href="https://colab.research.google.com/github/mynkpl1998/Attention-based-Image-Search/blob/main/Attn_based_image_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [93]:
from torchvision import transforms, datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
from tqdm import tqdm
import requests
from tabulate import tabulate
import pickle
import os
import random
from torch.utils.data import Dataset, DataLoader

from typing import Any, Callable, List, Optional, Tuple, Union

In [135]:
class SquarePad:
	def __call__(self, image):
		w, h = image.size
		max_wh = np.max([w, h])
		hp = int((max_wh - w) / 2)
		vp = int((max_wh - h) / 2)
		padding = (hp, vp, hp, vp)
		return F.pad(image, padding, 0, 'constant')

class Caltech5(Dataset):

  def __init__(self,
               split_ratio: float,
               split: str,
               transform: Optional[Callable] = None,
               target_transform: Optional[Callable] = None):

    use_labels = ["motorbikes", "butterfly", "accordion", "airplanes", "brain"]

    assert split == "train" or split == "test"
    self._split_type = split
    assert len(use_labels) == 5
    assert split_ratio > 0.0 and split_ratio <= 1.0

    if transform is not None:
      self._img_transform = transform
    else:
      self._img_transform = None

    img_dims = (300, 300)
    self._base_img_transforms = v2.Compose([
        SquarePad(),
        v2.Resize(img_dims)
    ])

    # Complete 101 dataset
    dataloader = datasets.Caltech101(root="./", download=True)

    # Labels map
    labels_map = {}
    labels_map[3] = "motorbikes"
    labels_map[4] = "accordion"
    labels_map[5] = "airplanes"
    labels_map[13] = "brain"
    labels_map[16] = "butterfly"

    self._labels_map = labels_map

    # label text -> idx
    reverse_label_map = { v:k for k,v in labels_map.items() }
    self._reverse_label_map = reverse_label_map

    # Admissible classes
    classes = []

    for l in use_labels:
      classes.append(reverse_label_map[l])

    # Map: class --> Images (list)
    truncated_dataset = {}

    dataset_summary = []
    processed_file = "processed_caltech5.pkl"

    if os.path.exists(processed_file):
      print("Found processed dataset file. Loading.")
      with open(processed_file, 'rb') as f:
        # Unpickle the dictionary
        truncated_dataset = pickle.load(f)
    else:
      print("Creating processed dataset...")
      for idx in tqdm(range(0, len(dataloader))):
        img, label = dataloader[idx]
        if label in classes:
          if label not in truncated_dataset:
            truncated_dataset[label] = []
          truncated_dataset[label].append(img)

      with open(processed_file, 'wb') as f:
        # Pickle the dictionary
        pickle.dump(truncated_dataset, f)
        print("Saved processed dataset.")

    for label_idx in truncated_dataset.keys():
       dataset_summary.append([labels_map[label_idx], len(truncated_dataset[label_idx])])

    print("\n")
    print(tabulate(dataset_summary, headers=["Label", "Num. Examples"]))

    # Split for train and test
    self._train_split, self._test_split = self._split_dataset(truncated_dataset, split_ratio)
    split_summary = []
    split_summary.append(["Train", len(self._train_split)])
    split_summary.append(["Test", len(self._test_split)])
    print("\n")
    print(tabulate(split_summary, headers=["Split Summary", "Num. Examples"]))

  def idx2Label(self, index):
    return self._labels_map[index]

  def _split_dataset(self, dataset, split_ratio):
    train_split = {}
    test_split = {}

    train_examples = 0
    test_examples = 0
    for label_idx in dataset:
      split_index = int(len(dataset[label_idx]) * split_ratio)
      train_data = dataset[label_idx][0:split_index]
      test_data = dataset[label_idx][split_index:]
      train_split[label_idx] = train_data
      test_split[label_idx] = test_data

      train_examples += len(train_data)
      test_examples += len(test_data)

    train_split_merged = self._shuffled_dict_to_list(train_split)
    test_split_merged = self._shuffled_dict_to_list(test_split)

    assert len(train_split_merged) == train_examples
    assert len(test_split_merged) == test_examples

    return train_split_merged, test_split_merged

  def _shuffled_dict_to_list(self, dataset):
    merged_list = []
    for key in dataset.keys():
      for img in dataset[key]:
        merged_list.append((img, key))
    random.shuffle(merged_list)
    return merged_list

  def __repr__(self):
    return "Split: %s\nNum Examples: %d"%(self._split_type, self.__len__())

  def __getitem__(self, idx):

    if self._split_type == "train":
      _split = self._train_split[idx]
    else:
      _split = self._test_split[idx]

    img, class_idx = _split

    # Convert Grayscale to RGB
    if len(img.size) == 2:
      img = img.convert("RGB")

    # Apply padding and reize
    img = self._base_img_transforms(img)

    # Other transforms
    if self._img_transform is not None:
      img = self._img_transform(img)
    return np.moveaxis(np.array(img), -1, 0), class_idx

  def __len__(self):
    if self._split_type == "train":
      return len(self._train_split)
    else:
      return len(self._test_split)


In [137]:
# Data augumentation to apply to image

transforms = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomPerspective(distortion_scale=0.4, p=.5),
    v2.RandomInvert(p=0.5),
])

training_data = Caltech5(split_ratio=0.8, split="test", transform=transforms)
''

dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
itr = iter(dataloader)

for idx, (train_features, train_labels) in enumerate(itr):
  print(idx +1, train_features.shape)
  #print(train_labels)
  #break

"""
random_idxs = np.random.randint(0, len(training_data), size=10)
for idx in random_idxs:
  img, _class = training_data[idx]
  label = training_data.idx2Label(_class)
  plt.title(label)
  plt.imshow(img)
  plt.show()
"""

Files already downloaded and verified
Found processed dataset file. Loading.


Label         Num. Examples
----------  ---------------
motorbikes              798
accordion                55
airplanes               800
brain                    98
butterfly                91


Split Summary      Num. Examples
---------------  ---------------
Train                       1472
Test                         370
1 torch.Size([32, 3, 300, 300])
2 torch.Size([32, 3, 300, 300])
3 torch.Size([32, 3, 300, 300])
4 torch.Size([32, 3, 300, 300])
5 torch.Size([32, 3, 300, 300])
6 torch.Size([32, 3, 300, 300])
7 torch.Size([32, 3, 300, 300])
8 torch.Size([32, 3, 300, 300])
9 torch.Size([32, 3, 300, 300])
10 torch.Size([32, 3, 300, 300])
11 torch.Size([32, 3, 300, 300])
12 torch.Size([18, 3, 300, 300])


'\nrandom_idxs = np.random.randint(0, len(training_data), size=10)\nfor idx in random_idxs:\n  img, _class = training_data[idx]\n  label = training_data.idx2Label(_class)\n  plt.title(label)\n  plt.imshow(img)\n  plt.show()\n'

In [134]:
transforms

Compose(
      RandomHorizontalFlip(p=0.5)
      RandomVerticalFlip(p=0.5)
      RandomPerspective(p=0.5, distortion_scale=0.4, interpolation=InterpolationMode.BILINEAR, fill=0)
      RandomInvert(p=0.5)
)

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=413x199>, 5)