In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import rotate

In [2]:
class ImageDataset(Dataset):
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index]

In [3]:
def collate_fn(batch):
    print("Here is the batch in `collate_fn`")
    print(batch)
    print("List of Tensor sizes in batch:")
    print([tensor.shape for tensor in batch])

    # get image dimensions, assuming all images are the same size
    # and it is number-of-channels (nc) first
    nc, h, w = batch[0].shape[-3], batch[0].shape[-2], batch[0].shape[-1]

    batch_size = len(batch)


    """to return batches of 3 (original image, augmented original, and another image)
    implies we need at least two separate images to work off of""" 
    if batch_size >= 2:
        print(f"Batch size: {batch_size}")
        # format: (batch_size, 3 for (original, augmented, another), num_channels, height, width)
        images_batch = torch.zeros((batch_size, 3, nc, h, w))
        
        for i in range(batch_size - 1):         
            first_image_index = i
            second_image_index = i + 1 # can be i and i + 1 if we toggle shuffle = True, so that it's random
            first_image = batch[first_image_index]
            second_image = batch[second_image_index]
            first_image_augmented = rotate(first_image, 180)

            images_batch[i, 0, :, :, :] = first_image
            images_batch[i, 1, :, :, :] = first_image_augmented
            images_batch[i, 2, :, :, :] = second_image
        
        images_batch[batch_size-1, 0, :, :, :] = batch[batch_size-1]
        images_batch[batch_size-1, 1, :, :, :] = rotate(batch[batch_size-1], 180)
        images_batch[batch_size-1, 2, :, :, :] = batch[0] 


    # otherwise, just create an (original image, augmented original) pair
    else:
        assert batch_size == 1
        images_batch = torch.zeros((batch_size, 2, nc, h, w))
        images_batch[0, 0, :, :, :] = batch[0]
        images_batch[0, 1, :, :, :] = rotate(batch[0], 180)
        
    return images_batch

In [33]:
(1, *[3, 4])

(1, 3, 4)

# Toy Data Experimentation

In [4]:
images = torch.randn(5, 1, 3, 3)
dataset = ImageDataset(images)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn, shuffle=True)

for batch in dataloader:
    batch = batch.view(batch.shape[0]*batch.shape[1], *batch.shape[2:]) 
    print(f"Batch output w/ shape {batch.shape}")
    print(batch)
    print("\n\n")

Here is the batch in `collate_fn`
[tensor([[[ 0.1320, -0.1185, -0.1576],
         [-0.5248, -0.3598, -1.1011],
         [ 0.2210,  0.0549,  2.9979]]]), tensor([[[-0.4236,  0.1173,  1.3468],
         [ 0.2980,  2.0543,  1.9087],
         [-0.7844,  0.9213, -0.8448]]])]
List of Tensor sizes in batch:
[torch.Size([1, 3, 3]), torch.Size([1, 3, 3])]
Batch size: 2
Batch output w/ shape torch.Size([6, 1, 3, 3])
tensor([[[[ 0.1320, -0.1185, -0.1576],
          [-0.5248, -0.3598, -1.1011],
          [ 0.2210,  0.0549,  2.9979]]],


        [[[ 2.9979,  0.0549,  0.2210],
          [-1.1011, -0.3598, -0.5248],
          [-0.1576, -0.1185,  0.1320]]],


        [[[-0.4236,  0.1173,  1.3468],
          [ 0.2980,  2.0543,  1.9087],
          [-0.7844,  0.9213, -0.8448]]],


        [[[-0.4236,  0.1173,  1.3468],
          [ 0.2980,  2.0543,  1.9087],
          [-0.7844,  0.9213, -0.8448]]],


        [[[-0.8448,  0.9213, -0.7844],
          [ 1.9087,  2.0543,  0.2980],
          [ 1.3468,  0.1173, -

In [5]:
images.shape

torch.Size([5, 1, 3, 3])

# CelebA

In [1]:
from torchvision.datasets import ImageFolder

class CustomImageFolder(ImageFolder):
    def __init__(self, root, transform=None):
        super(CustomImageFolder, self).__init__(root, transform)

    def __getitem__(self, index):
        path = self.imgs[index][0]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(), ])

train_kwargs = {'root': "./data/CelebA", 'transform': transform}
celeba_data = CustomImageFolder(**train_kwargs)

# 3DChairs