In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import pandas as pd
import numpy as np

In [3]:
import skimage
import sklearn
import matplotlib.pyplot as plt

In [4]:
import cv2
import os
import PIL

In [5]:
import torch
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.transforms import Compose, RandomResizedCrop, RandomGrayscale, RandomHorizontalFlip, GaussianBlur, ColorJitter, RandomSolarize

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset

In [7]:
# set torch seed for reproducibility when using random image augmentations
torch.manual_seed(0)

<torch._C.Generator at 0x16ad0c730>

## Helper Functions

In [8]:
def load_images_from_folder(folder):
    images = []
    for filename in os.listdir(folder):
        img = cv2.imread(os.path.join(folder,filename))
        if img is not None:
            images.append(img)
    return images

def image_to_patches(img):
    """Crop split_per_side x split_per_side patches from input image.
    Args:
        img (PIL Image): input image.
    Returns:
        list[PIL Image]: A list of cropped patches.
    """
    splits_per_side = 3  # split of patches per image side
    h, w = img.size()[1:]
    h_grid = h // splits_per_side
    w_grid = w // splits_per_side
    
    patches = [
        TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
        for i in range(splits_per_side) 
        for j in range(splits_per_side)
    ]
    
    return patches

## Load Data

In [10]:
tiny_imagenet = load_images_from_folder("./data/tiny-imagenet")
tiny_imagenet = [TF.to_pil_image(img) for img in tiny_imagenet]

## Experiments

In [11]:
# transform that will be applied to every raw image
TINY_IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



# random augmentations
AUGMENTATIONS = [
    RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)),
    RandomHorizontalFlip(p=0.5),
    ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    RandomGrayscale(p=0.5),
    GaussianBlur(kernel_size=23, sigma=(0.1, 0.2)),
    RandomSolarize(192, p=0.5),
]

In [30]:
ds = OurPatchLocalizationDataset(tiny_imagenet)

In [31]:
ds[0]

(tensor([[[[-1.6923e+00, -1.6923e+00, -1.7114e+00,  ..., -1.8400e+00,
            -1.8545e+00, -1.8545e+00],
           [-1.6923e+00, -1.6923e+00, -1.7114e+00,  ..., -1.8400e+00,
            -1.8545e+00, -1.8545e+00],
           [-1.7050e+00, -1.7050e+00, -1.7250e+00,  ..., -1.8339e+00,
            -1.8515e+00, -1.8515e+00],
           ...,
           [-1.0854e+00, -1.0854e+00, -1.0829e+00,  ..., -5.1854e-01,
            -4.8145e-01, -4.8145e-01],
           [-1.0965e+00, -1.0965e+00, -1.0939e+00,  ..., -6.0263e-01,
            -5.7222e-01, -5.7222e-01],
           [-1.0965e+00, -1.0965e+00, -1.0939e+00,  ..., -6.0263e-01,
            -5.7222e-01, -5.7222e-01]],
 
          [[-9.0669e-01, -9.0669e-01, -9.3452e-01,  ..., -6.9938e-01,
            -7.1662e-01, -7.1662e-01],
           [-9.0669e-01, -9.0669e-01, -9.3452e-01,  ..., -6.9938e-01,
            -7.1662e-01, -7.1662e-01],
           [-9.2442e-01, -9.2442e-01, -9.5304e-01,  ..., -6.9257e-01,
            -7.1382e-01, -7.1382e-01],


In [32]:
ds_loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=True, num_workers=2)

In [33]:
for X, y in ds_loader:
    print(X.shape)
    print(y)
    break

torch.Size([8, 8, 9, 224, 224])
tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7]])
