In [1]:
import torch 
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os 
from torch.utils.data import Dataset
import pandas as pd 
from skimage import io 

In [8]:
class CatsAndDogsDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs listed in a CSV table.

    Rows of the table are read positionally: column 0 is an image file name
    that is joined onto ``root_dir``, column 1 is converted to an ``int``
    label.

    Args:
        csvfile: Anything ``pandas.read_csv`` accepts (path or buffer) that
            holds the annotation table.
        root_dir: Directory prefix joined in front of every file name.
        transform: Optional callable applied to the decoded image before it
            is returned; ``None`` returns the image exactly as decoded.
    """

    def __init__(self, csvfile, root_dir, transform=None):
        self.annotations = pd.read_csv(csvfile)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of rows in the annotation table."""
        return self.annotations.shape[0]

    def __getitem__(self, index):
        """Decode sample ``index`` and return it as ``(image, label)``.

        The image comes from ``skimage.io.imread`` (optionally passed through
        ``self.transform``); the label is a 0-d ``torch`` tensor.
        """
        file_name = self.annotations.iloc[index, 0]
        image = io.imread(os.path.join(self.root_dir, file_name))
        target = torch.tensor(int(self.annotations.iloc[index, 1]))

        if self.transform is None:
            return image, target
        return self.transform(image), target

In [16]:
# Augmentation pipeline for the numpy arrays produced by CatsAndDogsDataset
# (skimage.io.imread). ToPILImage comes first so the PIL-based transforms
# below receive a PIL image.
my_transforms = transforms.Compose([
    transforms.ToPILImage(),
    # Force 3 channels: a grayscale (2-D) or RGBA image would otherwise reach
    # Normalize below with 1 or 4 channels and fail against its 3-value mean/std.
    transforms.Lambda(lambda pil_img: pil_img.convert("RGB")),
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.ColorJitter(brightness=0.5),
    transforms.RandomRotation(degrees=45),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.05),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # Computes (value - mean) / std per channel. With mean 0 and std 1 this is
    # an identity step, kept as a placeholder for real dataset statistics.
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
])

In [17]:
# Build the augmented dataset. os.path.join keeps the paths portable instead of
# hard-coding Windows backslash separators (which POSIX treats as part of the
# file name). For un-augmented tensors, pass transform=transforms.ToTensor().
DATA_DIR = os.path.join('dataset', 'cats_dogs')

dataset = CatsAndDogsDataset(
    os.path.join(DATA_DIR, 'cats_dogs.csv'),
    os.path.join(DATA_DIR, 'cats_dogs_resized'),
    transform=my_transforms,
)

In [18]:
# Write NUM_PASSES augmented copies of every sample in `dataset` to save_path
# as img0.png, img1.png, ...; each pass draws fresh random augmentations.
NUM_PASSES = 10
SEED = 0

# Trailing '' keeps a trailing separator, so `save_path + name` still works.
save_path = os.path.join('dataset', 'cats_dogs_augmentations', '')
os.makedirs(save_path, exist_ok=True)  # create the output folder up front

torch.manual_seed(SEED)  # torchvision's random transforms use torch's RNG

img_num = 0
for _ in range(NUM_PASSES):
    for index in range(len(dataset)):
        img, _label = dataset[index]  # label is not needed for saving images
        save_image(img, os.path.join(save_path, f'img{img_num}.png'))
        img_num += 1
