# Imports

In [1]:
from torchvision import transforms as T
from torch.utils.data.dataset import Dataset
from PIL import Image
import os
import torch
from pickle import HIGHEST_PROTOCOL

# Useful Constants

In [2]:
root_dir = '/Users/gbotev/Downloads/archive/memes'
new_size = 64

# Initialize `CustomDataset`

In [3]:
class CustomDataset(Dataset):
    """Custom dataset."""

    def __init__(self, root_dir, transforms=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transforms = transforms
        self.img_names = [name for name in os.listdir(root_dir) if os.path.isfile(os.path.join(self.root_dir, name))]
        self.num_imgs = len(self.img_names)

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root_dir,
                                      self.img_names[index]))
        if self.transforms:
            img = self.transforms(img)
        return img

# Calculate Normalization and Initialize `CustomDataset`

In [4]:
dataset = CustomDataset(root_dir,
                        T.Compose([T.Resize(new_size),
                                   T.ToTensor()]))
means = []
stds = []
for img in dataset:
    means.append(torch.mean(img))
    stds.append(torch.std(img))
mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
print(f'Mean: {mean}\n Std: {std}')
norm = T.Normalize(mean=mean, std=std)

Mean: 0.6056310534477234
 Std: 0.2573034465312958


In [5]:
dataset = CustomDataset(root_dir,
                        T.Compose([T.Resize(new_size),
                                   T.ToTensor(),
                                   norm]))

Sanity check to make sure we have 3,326 images.

In [6]:
len(dataset)

3326

# Save Dataset

In [7]:
torch.save(dataset, 'reddit_meme_dataset.pt', pickle_protocol=HIGHEST_PROTOCOL)