In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%load_ext snakeviz

In [3]:
import torch
import torchvision
from torchvision.io import ImageReadMode
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import Compose, RandomResizedCrop, RandomGrayscale, RandomHorizontalFlip, GaussianBlur, ColorJitter, RandomSolarize, ToPILImage, ToTensor, RandomCrop, CenterCrop, Resize, Normalize

# set torch seed for reproducibility when using random image augmentations
torch.manual_seed(0)

<torch._C.Generator at 0x110365850>

In [4]:
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset, image_to_patches
from src.transforms import RandomColorDropping, ColorProjection, IMAGENET_NORMALIZATION_PARAMS, IMAGENET_RESIZE

In [5]:
from typing import List, Union, Tuple
import PIL
import numpy as np
import pandas as pd
import os
from matplotlib import pyplot as plt

In [6]:
# load image paths of a previously sampled subset of ImageNet
image_paths = np.loadtxt("./data/imagenet_validation_RGB_10perc_subset.csv", dtype=str)

In [7]:
ds_orig_cache = OriginalPatchLocalizationDataset(image_paths, cache_images=True)
ds_orig_no_cache = OriginalPatchLocalizationDataset(image_paths, cache_images=False)

In [8]:
ds_our_cache = OurPatchLocalizationDataset(image_paths, cache_images=True)
ds_our_no_cache = OurPatchLocalizationDataset(image_paths, cache_images=False)

In [9]:
ds_orig_cache[0]

([tensor([[[ 1.2928,  1.2928,  1.2129,  ..., -0.2674, -0.2941, -0.2941],
           [ 1.2928,  1.2928,  1.2129,  ..., -0.2674, -0.2941, -0.2941],
           [ 1.2064,  1.2064,  1.1425,  ..., -0.2963, -0.3266, -0.3266],
           ...,
           [-1.1999, -1.1999, -1.1944,  ..., -1.8341, -1.9170, -1.9170],
           [-1.2189, -1.2189, -1.2039,  ..., -1.8475, -1.9295, -1.9295],
           [-1.2189, -1.2189, -1.2039,  ..., -1.8475, -1.9295, -1.9295]],
  
          [[ 1.1972,  1.1972,  1.1177,  ..., -0.3353, -0.3550, -0.3550],
           [ 1.1972,  1.1972,  1.1177,  ..., -0.3353, -0.3550, -0.3550],
           [ 1.1090,  1.1090,  1.0460,  ..., -0.3647, -0.3871, -0.3871],
           ...,
           [-1.2056, -1.2056, -1.2000,  ..., -1.6886, -1.7592, -1.7592],
           [-1.2304, -1.2304, -1.2151,  ..., -1.7045, -1.7731, -1.7731],
           [-1.2304, -1.2304, -1.2151,  ..., -1.7045, -1.7731, -1.7731]],
  
          [[ 1.1614,  1.1614,  1.0844,  ..., -0.3023, -0.3142, -0.3142],
           

In [10]:
ds_our_cache[0]

([tensor([[[ 0.8475,  0.8475,  0.7833,  ...,  0.2196,  0.2367,  0.2367],
           [ 0.8475,  0.8475,  0.7833,  ...,  0.2196,  0.2367,  0.2367],
           [ 0.8504,  0.8504,  0.7950,  ...,  0.2056,  0.2218,  0.2218],
           ...,
           [-1.1229, -1.1229, -1.1490,  ..., -1.6760, -1.6873, -1.6873],
           [-1.1218, -1.1218, -1.1493,  ..., -1.7308, -1.7297, -1.7297],
           [-1.1218, -1.1218, -1.1493,  ..., -1.7308, -1.7297, -1.7297]],
  
          [[ 0.6370,  0.6370,  0.5714,  ...,  0.0454,  0.0651,  0.0651],
           [ 0.6370,  0.6370,  0.5714,  ...,  0.0454,  0.0651,  0.0651],
           [ 0.6443,  0.6443,  0.5877,  ...,  0.0393,  0.0586,  0.0586],
           ...,
           [-1.0962, -1.0962, -1.1237,  ..., -1.5902, -1.6053, -1.6053],
           [-1.0787, -1.0787, -1.1078,  ..., -1.6367, -1.6389, -1.6389],
           [-1.0787, -1.0787, -1.1078,  ..., -1.6367, -1.6389, -1.6389]],
  
          [[ 0.4991,  0.4991,  0.4338,  ..., -0.0397, -0.0180, -0.0180],
           

In [11]:
%%timeit
ds_orig_cache[0]

696 µs ± 12.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [12]:
%%timeit
ds_orig_no_cache[0]

4.88 ms ± 451 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%%timeit
ds_our_cache[0]

55.8 ms ± 2.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
%%timeit
ds_our_no_cache[0]

60.2 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
%%snakeviz
ds_orig_cache[0]

 
*** Profile stats marshalled to file '/var/folders/1p/dx2hbmg90j579twjxstztgvh0000gn/T/tmp7os49oye'.
Embedding SnakeViz in this document...


In [18]:
%%snakeviz
ds_orig_no_cache[0]

 
*** Profile stats marshalled to file '/var/folders/1p/dx2hbmg90j579twjxstztgvh0000gn/T/tmpv4oit3_3'.
Embedding SnakeViz in this document...


In [19]:
%%snakeviz
ds_our_cache[0]

 
*** Profile stats marshalled to file '/var/folders/1p/dx2hbmg90j579twjxstztgvh0000gn/T/tmpoiwluvmt'.
Embedding SnakeViz in this document...


In [20]:
%%snakeviz
ds_our_no_cache[0]

 
*** Profile stats marshalled to file '/var/folders/1p/dx2hbmg90j579twjxstztgvh0000gn/T/tmp8_ymjtjg'.
Embedding SnakeViz in this document...
