In [76]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import pdb
from autoaugment import ImageNetPolicy, CIFAR10Policy, SVHNPolicy, SubPolicy
from collections import defaultdict
import os
from matplotlib.pyplot import imshow

In [77]:
# Shared augmentation policy object; getAugmentedImages() calls it once per augmented image it generates.
policy = ImageNetPolicy()

In [1]:
def getAugmentedImages(original_image_path, no_of_aug_images=8):
    """
    Produce auto-augmented variants of the image stored at original_image_path.

    Every variant is obtained by calling the module-level ``policy`` object on the
    loaded image. A variant whose generation raises an exception is reported and
    skipped, so the returned list may hold fewer than no_of_aug_images entries.

    Args:
        original_image_path : string path to an image
        no_of_aug_images : number of augmented images to attempt to generate

    Returns:
        list of whatever ``policy`` returned for each successful attempt
        (possibly shorter than requested, possibly empty)
    """
    # Image.open() is lazy, so read the pixel data inside a context manager. This
    # releases the file handle before the augmentation loop instead of keeping it
    # alive for as long as the image (or anything referencing it) is retained.
    with Image.open(original_image_path) as opened:
        if opened.mode == "RGBA":
            # ImageNetPolicy() cannot augment "RGBA" images, so flatten the image onto a
            # white "RGB" background, using the alpha channel as the paste mask.
            opened.load()
            original_image = Image.new("RGB", opened.size, (255, 255, 255))
            original_image.paste(opened, mask=opened.split()[3])
        else:
            # copy() forces the load and returns an image independent of the file.
            original_image = opened.copy()

    imgs = []
    for _ in range(no_of_aug_images):
        try:
            imgs.append(policy(original_image))
        except Exception as err:
            # The policy can fail on certain image/operation combinations; skip that
            # attempt but say why. Catching Exception (not a bare except) keeps
            # KeyboardInterrupt/SystemExit working so a long run can still be stopped.
            print("Some error encountered: {0!r}".format(err))
            continue
    return imgs

In [79]:
def augmentMultipleLabels(dict_labels_to_photos, n_images_per_label = 400):
    """
    Get augmented images for all the images of every label and store them, in place,
    in the provided dictionary.

    Args:
        dict_labels_to_photos : a dictionary of labels mapping to dictionaries whose keys are the paths to
                                images of that label and whose values are initially None, to be replaced by
                                a list containing augmented images of the key image.
                                Ex: d = {"cat":{"data/train/cat/0.jpg" : None,
                                                "data/train/cat/1.jpg" : None},
                                         "dog":{"data/train/dog/43.jpg" : None}}
                                None is replaced by the list returned by getAugmentedImages()
        n_images_per_label : approximate total number of images (originals + augmented) a label should
                             have after augmentation

    Returns:
        None; dict_labels_to_photos is modified in place.
    """
    for label, photos in dict_labels_to_photos.items():
        no_photos_in_label = len(photos)
        if no_photos_in_label == 0:
            # Nothing to augment for this label, and it would divide by zero below.
            continue

        # Spread the missing images evenly over the existing ones (rounded up). A label that
        # already has at least n_images_per_label photos gets zero augmentations per image
        # rather than a negative count.
        missing = n_images_per_label - no_photos_in_label
        no_aug_per_image = max(0, int(math.ceil(missing / no_photos_in_label)))

        # Only values of existing keys are reassigned, so iterating the dict here is safe.
        for image in photos:
            photos[image] = getAugmentedImages(image, no_aug_per_image)

In [80]:
def saveImages(rootdir, root_name, images):
    """
    Save a list of images into rootdir as JPEG files named "<root_name>_<index>.jpg".

    Args:
        rootdir : directory where the images should be saved (must already exist)
        root_name : prefix of the image name
        images : list of images to save; each must expose ``mode``, ``convert()`` and
                 ``save()`` (PIL images do)
    """
    # Modes the JPEG encoder can write directly. Anything else (e.g. "P", "RGBA", "LA")
    # makes save() raise when the target is a .jpg, so those are converted to RGB first.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    for i, image in enumerate(images):
        image_name = "{0}_{1}.jpg".format(root_name, i)
        img_path = os.path.join(rootdir, image_name)
        if image.mode not in jpeg_writable_modes:
            image = image.convert("RGB")
        image.save(img_path)

In [81]:
# File extensions (with the leading dot) that makeAugmented() accepts as image files.
common_img_extensions = {
    '.tif', '.tiff',                          # TIFF
    '.gif',                                   # GIF
    '.jpeg', '.jpg', '.jif', '.jfif',         # JPEG
    '.jp2', '.jpx', '.j2k', '.j2c',           # JPEG 2000
    '.fpx',                                   # FlashPix
    '.pcd',                                   # Photo CD
    '.png',                                   # PNG
}

In [2]:
def makeAugmented(directory, classes, n_images_per_label=200):
    """
    Create auto-augmented images of all the images in a directory and save them under
    "<directory>/augment/<label>/".

    Args:
        directory : string path to the directory containing the folders (folders represent labels) that contain
                    the images to augment
        classes : list of labels on whose images auto-augmentation is to be performed
        n_images_per_label : approximate number of images each label should have after augmentation;
                             forwarded to augmentMultipleLabels()
    """
    # path to directory where all the augmented images will be stored
    augment_dir = os.path.join(directory, "augment")

    # aug maps label -> {image path -> None}; augmentMultipleLabels() later replaces each None
    # with the list of augmented images for that path
    aug = defaultdict(dict)
    for root, dirs, files in os.walk(directory):
        # Do not descend into the output folder: it contains sub-folders named after the same
        # labels, so a second run would otherwise augment the already-augmented images.
        dirs[:] = [d for d in dirs if os.path.join(root, d) != augment_dir]

        label = os.path.split(root)[-1] # from path, get the folder name (label)
        if label not in classes:
            continue

        for f in files:
            # keep only image files; compare case-insensitively so "photo.JPG" is not skipped
            extension = os.path.splitext(f)[1].lower()
            if extension not in common_img_extensions:
                continue
            aug[label][os.path.join(root, f)] = None

    # once aug has all the labels and their corresponding images, generate the augmented images
    # for each of those images
    augmentMultipleLabels(aug, n_images_per_label=n_images_per_label)

    os.makedirs(augment_dir, exist_ok=True)

    # store all the augmented images in augment directory which further contains directories like "cat" and "dog"
    for label, photos in aug.items():
        # label directory inside augment directory
        aug_label_dir = os.path.join(augment_dir, label)
        os.makedirs(aug_label_dir, exist_ok=True)

        for image, augmented_images in photos.items():
            image_id = os.path.splitext(os.path.basename(image))[0] # image = "Users/dave/1.jpg" gives image_id = "1"
            saveImages(aug_label_dir, image_id, augmented_images)
    

        

In [83]:
# Root folder passed to makeAugmented(): it is walked for label sub-folders, and the
# "augment" output folder is created inside it.
directory = "data/train"
# Labels (folder names) whose images should be augmented.
classes = ['cat', 'dog', 'rabbit']

In [None]:
# Run the augmentation for the configured classes; results are saved under <directory>/augment/<label>/.
makeAugmented(directory, classes)