In [1]:
from PIL import Image
import glob
import torchvision.transforms as transforms
import numpy as np
import csv
import os

In [2]:
class MakeSquare:
    """Pad a PIL image so its width equals its height.

    Only the shorter side is padded, split across both of its edges; when the
    amount is odd the extra pixel goes on the left (width) or top (height).
    The fill is whatever `transforms.functional.pad` uses by default.
    """

    def __call__(self, image):
        width, height = image.size
        side = max(width, height)
        extra_x = side - width
        extra_y = side - height
        # Integer ceil(extra / 2) goes to the leading edge, floor to the trailing one.
        left, right = extra_x - extra_x // 2, extra_x // 2
        top, bottom = extra_y - extra_y // 2, extra_y // 2
        return transforms.functional.pad(image, (left, top, right, bottom))

In [13]:
def get_transformed(categories=('alex', 'brendan', 'ethan', 'jerry', 'jon',
                                'josh', 'martin', 'mitchell', 'speero'),
                    raw_dir='images/raw', out_dir='images/transformed',
                    size=224, n_augment=9, seed=None):
    """Build the square training images for every category.

    For each ``<raw_dir>/<cat>/*.jpg`` this writes one resized/center-cropped
    copy plus ``n_augment`` randomly rotated and brightness-jittered copies to
    ``<out_dir>/<cat>/``. The plain copy of the i-th source image is saved as
    ``i.jpg`` and its k-th augmentation as ``(i + k * len(images)).jpg``, so
    file names never collide within a category.

    Args:
        categories: sub-directory names to process (one per class).
        raw_dir: root directory holding the source images.
        out_dir: root directory the processed images are written to.
        size: side length of the square output images.
        n_augment: number of augmented copies written per source image.
        seed: optional seed for the random augmentations; ``None`` leaves the
            RNG state untouched (non-reproducible output).

    NOTE(review): only lowercase ``*.jpg`` files are picked up, and files left
    over from an earlier run with more source images are not removed.
    """
    from PIL import ImageOps

    if seed is not None:
        import random

        import torch

        # NOTE(review): recent torchvision draws transform parameters from
        # torch's RNG, older versions from ``random`` -- seed both.
        random.seed(seed)
        torch.manual_seed(seed)

    # The pipelines do not depend on the image, so build them once.
    transform = transforms.Compose([
        transforms.Resize(size),      # shorter side -> size, aspect ratio kept
        transforms.CenterCrop(size),  # crop the longer side down to a square
    ])
    augment = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.5),
    ])

    for cat in categories:
        # Sorted so the index -> output file mapping is stable between runs.
        images = sorted(glob.glob(os.path.join(raw_dir, cat, '*.jpg')))
        loc = os.path.join(out_dir, cat)
        os.makedirs(loc, exist_ok=True)
        for index, path in enumerate(images):
            # ``with`` closes the source file handle once we are done with it.
            with Image.open(path) as src:
                # Apply any EXIF orientation tag so photos are not saved sideways.
                img = ImageOps.exif_transpose(src)
                transform(img).save(os.path.join(loc, f'{index}.jpg'))
                for aug_num in range(1, n_augment + 1):
                    augmented = transform(augment(img))
                    augmented.save(
                        os.path.join(loc, f'{index + aug_num * len(images)}.jpg'))

In [15]:
def create_csv(categories=('alex', 'brendan', 'ethan', 'jerry', 'jon',
                           'josh', 'martin', 'mitchell', 'speero'),
               image_root='images/transformed', csv_path='train.csv'):
    """Write an ``image,label`` CSV index of the processed images.

    Every ``<image_root>/<cat>/*.jpg`` becomes one row whose label is the
    position of ``cat`` in ``categories``. The category order therefore
    defines the integer class ids.

    Args:
        categories: class directory names, in label order.
        image_root: directory containing one sub-directory per category.
        csv_path: output CSV file (overwritten).
    """
    # newline='' stops text-mode newline translation, so rows end in exactly '\n'.
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile, lineterminator='\n')
        writer.writerow(['image', 'label'])
        for label, cat in enumerate(categories):
            # Sorted so the CSV is identical between runs (glob order is arbitrary).
            for image in sorted(glob.glob(f'{image_root}/{cat}/*.jpg')):
                writer.writerow([image, label])

In [16]:
# Regenerate images/transformed/<cat>/ from images/raw/<cat>/*.jpg (resize + center-crop, plus random augmentations).
get_transformed()

In [17]:
# Index every images/transformed/<cat>/*.jpg into train.csv together with its integer category label.
create_csv()

In [13]:
# Sanity check: every processed image should be exactly 224x224; print any that are not.
for image in sorted(glob.glob('images/transformed/*/*.jpg')):
    # ``with`` closes each file; ``size`` is read from the header without decoding pixels.
    with Image.open(image) as im:
        if im.size != (224, 224):
            print(image)