In [1]:
import collections
import collections.abc
import os

import torch
import torchvision.transforms as transforms
from PIL import Image

from TripletFaceDataset import TripletFaceDataset

In [2]:
class Scale(object):
    """Resize a PIL image either to an exact size or by its shorter edge.

    Args:
        size: an ``int`` or a 2-element sequence ``(width, height)``.
            - 2-element sequence: the image is resized to exactly
              ``(width, height)``; the aspect ratio is not preserved.
            - ``int``: the shorter edge becomes ``size`` and the longer edge
              is scaled to keep the aspect ratio (truncated with ``int``).
              For example, if height > width the output ``(width, height)``
              is ``(size, int(size * height / width))``. An image whose
              shorter edge already equals ``size`` is returned unchanged.
        interpolation: PIL resampling filter passed to ``Image.resize``.
            Default: ``PIL.Image.BILINEAR``. Note that this is the SECOND
            positional argument, so ``Scale(224, 224)`` does NOT mean
            224x224 -- write ``Scale((224, 224))`` for an exact size.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # `collections.Iterable` was removed in Python 3.10; the ABC lives in
        # `collections.abc`.
        assert isinstance(size, int) or (
            isinstance(size, collections.abc.Iterable) and len(size) == 2
        ), "size must be an int or a (width, height) pair, got {!r}".format(size)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return ``img`` resized according to ``self.size``.

        Args:
            img: a PIL image (uses ``img.size`` as ``(w, h)`` and ``img.resize``).

        Returns:
            The resized image, or ``img`` itself when its shorter edge already
            equals an integer ``self.size``.
        """
        if not isinstance(self.size, int):
            # PIL documents `size` as a 2-tuple; normalise lists/other sequences.
            return img.resize(tuple(self.size), self.interpolation)

        w, h = img.size
        if min(w, h) == self.size:
            # Shorter edge already has the requested length: skip resampling.
            return img
        if w < h:
            ow, oh = self.size, int(self.size * h / w)
        else:
            ow, oh = int(self.size * w / h), self.size
        return img.resize((ow, oh), self.interpolation)

In [3]:
# Resize every face image to exactly 224x224 (width, height) so all samples share
# one shape (presumably required to batch them with the DataLoader below), then
# convert to a tensor.
# Scale's second positional parameter is `interpolation`, so the size must be a
# single tuple: `Scale(224, 224)` would pass 224 as the PIL resampling filter.
transform = transforms.Compose([
    Scale((224, 224)),
    transforms.ToTensor(),
    # Normalization is currently disabled; re-enable to map [0, 1] inputs to [-1, 1]:
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [4]:
# Request pinned host memory from the DataLoader only when a CUDA device is
# actually present (pinning is useless on a CPU-only machine).
cuda = torch.cuda.is_available()
kwargs = {'num_workers': 0, 'pin_memory': True} if cuda else {}

# Root folder of the training images. Set the FACENET_DATAROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
dataroot = os.environ.get('FACENET_DATAROOT',
                          'E:/College/Final Year Project/facenet_pytorch/data/train')
n_triplets = 1000  # number of triplets TripletFaceDataset generates
batch_size = 64

In [5]:
# Build the triplet dataset from the images under `dataroot` (it reports
# "Generating ... triplets" while building), then wrap it in a DataLoader that
# serves it in a fixed order (no shuffling) in batches of `batch_size`.
train_dir = TripletFaceDataset(
    dir=dataroot,
    n_triplets=n_triplets,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dir,
    batch_size=batch_size,
    shuffle=False,
    **kwargs
)

Generating 1000 triplets


100%|███████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 83556.87it/s]


In [6]:
# Explicit iterator over the training DataLoader; presumably consumed by later
# cells with next(data_iter) to inspect individual batches.
data_iter = iter(train_loader)