In [22]:
import torch
from torchvision.transforms import (
    Compose,
    Grayscale,    
    Normalize,
    Resize,
    RandomHorizontalFlip,
    ToTensor,    
)
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [23]:
transform = {
    'head': Compose([
        RandomHorizontalFlip(),
        Grayscale(),
    ]),
    'large': Compose([
        Resize([128, 128]),
        ToTensor(),
        Normalize(mean=(0.5), std=(0.5)),
    ]),
    'medium': Compose([
        Resize([64, 64]),
        ToTensor(),
        Normalize(mean=(0.5), std=(0.5)),
    ]),
    'small': Compose([
        Resize([32, 32]),
        ToTensor(),
        Normalize(mean=(0.5), std=(0.5)),
    ]),
}

def img_transform(img):
    flipped = transform['head'](img)
    return {
        'large': transform['large'](flipped),
        'medium': transform['medium'](flipped),
        'small': transform['small'](flipped),
    }    
    
root = '/home/bobi/Desktop/db/ffhq-dataset/thumbnails'
ds = ImageFolder(root, transform=img_transform)
ds

Dataset ImageFolder
    Number of datapoints: 70000
    Root location: /home/bobi/Desktop/db/ffhq-dataset/thumbnails
    StandardTransform
Transform: <function img_transform at 0x7f71b33a14c0>

In [24]:
ds[1100][0]

{'large': tensor([[[ 0.3804,  0.3882,  0.4039,  ...,  0.3412,  0.3098,  0.2784],
          [ 0.3882,  0.3961,  0.4039,  ...,  0.5137,  0.5059,  0.5059],
          [ 0.3804,  0.3882,  0.3882,  ...,  0.5373,  0.5373,  0.5373],
          ...,
          [-0.4353, -0.4510, -0.4588,  ...,  0.5765,  0.5686,  0.5765],
          [-0.4118, -0.4510, -0.4902,  ...,  0.5765,  0.5686,  0.5608],
          [-0.4588, -0.4510, -0.4980,  ...,  0.5765,  0.5686,  0.5686]]]),
 'medium': tensor([[[ 0.3882,  0.3961,  0.3961,  ...,  0.4824,  0.4510,  0.4196],
          [ 0.3961,  0.3961,  0.3961,  ...,  0.5451,  0.5451,  0.5451],
          [ 0.3961,  0.3961,  0.4039,  ...,  0.5451,  0.5529,  0.5529],
          ...,
          [-0.4118, -0.1686,  0.1608,  ...,  0.5765,  0.5765,  0.5765],
          [-0.4275, -0.3098,  0.0431,  ...,  0.5686,  0.5765,  0.5765],
          [-0.4510, -0.4196, -0.0902,  ...,  0.5686,  0.5765,  0.5686]]]),
 'small': tensor([[[ 0.3961,  0.3961,  0.3961,  ...,  0.5294,  0.5216,  0.4980],


In [28]:
loader = DataLoader(ds, batch_size=8, shuffle=True)
loader

<torch.utils.data.dataloader.DataLoader at 0x7f71b2f2b520>

In [30]:
res = next(iter(loader))
res[0]['large'].shape, res[0]['medium'].shape, res[0]['small'].shape

(torch.Size([8, 1, 128, 128]),
 torch.Size([8, 1, 64, 64]),
 torch.Size([8, 1, 32, 32]))