## Libraries and Utilities

In [None]:
# Generic python libraries
import os
import math
import random
from PIL import Image
from random import sample

# PyTorch and numpy related
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split

# External tool (built by us) to format the dataset structure.
# utils.py is fetched from the project's GitHub repository only when it is not
# already in the working directory, so re-running this cell skips the download.
if not os.path.exists('utils.py'):
  !wget https://raw.githubusercontent.com/humbertordrgs/VR_DL_2/master/utils/utils.py

from utils import load_images, get_class_map, get_processed_img, execution_time

## Downloading Datasets

In [None]:
# Sketches dataset: download and extract the archive unless the zip file is
# already in the working directory. Extraction is chained to the download, so
# it only runs right after a fresh (successful) wget.
if not os.path.exists('Sketch_EITZ.zip'):
  !wget https://www.dropbox.com/s/ut350iwgby9swk2/Sketch_EITZ.zip && unzip -q Sketch_EITZ.zip

# Image gallery used for training and validation; same download-once guard
# keyed on the presence of the zip file.
if not os.path.exists('Flickr25K.zip'):
  !wget https://www.dropbox.com/s/khbxruh3acq84eg/Flickr25K.zip && unzip -q Flickr25K.zip

# Gallery Image Retrieval for Test
# if not os.path.exists('Flickr15K.zip'):
#   !wget https://www.dropbox.com/s/q5ew09x4e3rsiht/Flickr15K.zip && unzip -q Flickr15K.zip


In [None]:
class SketchBasedImageRetrievalDataset(Dataset):
  """Sketch-based image retrieval dataset held fully in memory.

  For every entry of the sketch index file the constructor loads the sketch,
  one randomly chosen gallery image of the same class (the positive) and,
  when ``use_triplets`` is set, one randomly chosen gallery image of a
  different class (the negative). The random choices are made once, at
  construction time, so a given instance always returns the same
  pairs/triplets.

  Items are ``((sketch, positive), class_id)`` or, in triplet mode,
  ``((sketch, positive, negative), class_id)``. The images are whatever
  ``get_processed_img`` returns (presumably tensors -- confirm in utils.py).
  """

  # Class ids are assumed to be the integers 0..NUM_CLASSES-1, written as
  # strings in the index file and used as keys of the class map; negative
  # classes are drawn from that range.
  NUM_CLASSES = 250

  @execution_time
  def __init__(
    self, sketch_folder_path, sketch_index_file,
    image_gallery_folder_path, mapping_file_path, use_triplets = False
  ):
    """Load every sketch listed in the index file together with its images.

    Args:
      sketch_folder_path: Folder holding the sketches, the index file and
        the mapping file.
      sketch_index_file: Name of the index file (relative to
        ``sketch_folder_path``); each non-blank line is
        ``<relative sketch path> <class id>``.
      image_gallery_folder_path: Root folder of the image gallery, with one
        sub-folder per class name.
      mapping_file_path: Name of the class mapping file (relative to
        ``sketch_folder_path``) passed to ``get_class_map``.
      use_triplets: When true, also load one negative image per sketch.

    Raises:
      ValueError: If a non-blank index line does not have exactly two
        whitespace-separated fields.
    """
    self.use_triplets = use_triplets
    self.sketches = []
    self.classes = []
    self.positive_images = []
    if use_triplets:
      self.negative_images = []

    class_map = get_class_map(sketch_folder_path + "/" + mapping_file_path)
    structured_images = load_images(image_gallery_folder_path)
    index_file_path = sketch_folder_path + "/" + sketch_index_file

    with open(index_file_path, "r") as sketch_file:
      for line in sketch_file:
        # Blank lines (e.g. a trailing newline at the end of the file) carry
        # no entry; skip them instead of failing on the unpacking below.
        if not line.strip():
          continue

        sketch_path, sketch_idx = line.split()
        sketch_path = sketch_folder_path + "/" + sketch_path
        sketch_class = class_map[sketch_idx]

        positive_image_path = self._random_image_path(
          image_gallery_folder_path, structured_images, sketch_class
        )

        self.sketches.append(get_processed_img(sketch_path))
        self.positive_images.append(get_processed_img(positive_image_path))
        self.classes.append(int(sketch_idx))

        if use_triplets:
          # Redraw until the class id differs from the sketch's own class.
          # Ids are compared as strings, exactly as they appear in the file.
          negative_idx = sketch_idx
          while negative_idx == sketch_idx:
            negative_idx = str(random.randint(0, self.NUM_CLASSES - 1))

          negative_image_path = self._random_image_path(
            image_gallery_folder_path, structured_images, class_map[negative_idx]
          )
          self.negative_images.append(get_processed_img(negative_image_path))

  @staticmethod
  def _random_image_path(gallery_root, images_by_class, class_name):
    """Return the path of one randomly sampled gallery image of ``class_name``."""
    image_name = sample(images_by_class[class_name], 1)[0]
    return f"{gallery_root}/{class_name}/{image_name}"

  def __len__(self):
    """Return the number of loaded sketches."""
    return len(self.sketches)

  def __getitem__(self, idx):
    """Return ``(images, class_id)`` for entry ``idx``.

    ``images`` is ``(sketch, positive, negative)`` in triplet mode and
    ``(sketch, positive)`` otherwise.
    """
    if self.use_triplets:
      return (self.sketches[idx], self.positive_images[idx], self.negative_images[idx]), self.classes[idx]
    return (self.sketches[idx], self.positive_images[idx]), self.classes[idx]

In [None]:
# Build the training set in triplet mode: every item carries a sketch, a
# positive image and a negative image.
train_dataset = SketchBasedImageRetrievalDataset(
  "Sketch_EITZ",
  "train.txt",
  "Flickr25K",
  "mapping.txt",
  use_triplets=True,
)

Constructor: 196.35094 seconds


In [None]:
# Serialize the whole dataset object (including the images loaded by its
# constructor) to a single file on disk.
torch.save(obj=train_dataset, f='SBIR_train_dataset.pt')

In [None]:
# Compress the saved dataset file into a zip archive.
!zip SBIR_train_dataset.zip SBIR_train_dataset.pt

  adding: SBIR_train_dataset.pt (deflated 81%)
