In [None]:
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img, array_to_img, save_img
import os
from matplotlib import pyplot as plt
import numpy as np
import os
from constants import class_type_to_index
import math

In [None]:
def augment_single_class(dir_path, filename_tupels, generator, factor, limit):
    """Write randomly augmented copies of one class's images until the class reaches `limit` files.

    Makes up to `factor` passes over `filename_tupels`. For every source image it applies
    `generator.random_transform`, then either a random per-channel intensity change
    (`mess_up_colors`) or a random channel permutation (`swith_colors`). The result is
    saved next to the source image as ``_<pass>_<filename>``.

    Args:
        dir_path: Root directory containing one sub-directory per dataset folder.
        filename_tupels: Sequence of ``(dirname, filename)`` pairs belonging to this class.
        generator: Keras ``ImageDataGenerator`` providing the random transform.
        factor: Maximum number of passes over the source files.
        limit: Target file count for the class. Counting starts at ``len(filename_tupels)``,
            so at most ``limit - len(filename_tupels)`` new files are written.
    """
    count = len(filename_tupels)
    for iteration in range(factor):
        for dirname, filename in filename_tupels:
            if count >= limit:
                # Target reached: stop every remaining pass, not just the current one
                # (a plain `break` here only left the inner loop).
                return

            source_path = os.path.join(dir_path, dirname, filename)
            x = img_to_array(load_img(source_path))
            x_transformed = generator.random_transform(x)

            # Coin flip between the two color augmentations.
            if np.random.choice([0, 1]) == 1:
                x_transformed = mess_up_colors(x_transformed)
            else:
                x_transformed = swith_colors(x_transformed)

            # Saved into the source directory, so later directory listings include these files.
            target_path = os.path.join(dir_path, dirname, '_' + str(iteration) + '_' + filename)
            save_img(target_path, array_to_img(x_transformed))

            count += 1

def swith_colors(x):
    """Return a copy of `x` with the entries along axis 2 randomly reordered.

    For an image array laid out as H x W x C (assumed, as produced by img_to_array),
    this shuffles the color channels.
    """
    order = np.random.permutation(x.shape[2])
    return np.take(x, order, axis=2)

def mess_up_colors(x, max_intensity=1.5):
    """Randomly rescale the first (up to three) channels of an image array and clip to [0, 255].

    Each channel is multiplied by its own factor drawn from
    ``np.random.uniform(0, max_intensity)``. The input array is left unmodified.

    Args:
        x: Image array indexed as ``x[:, :, channel]`` (presumably H x W x C float32 from
            img_to_array; integer arrays raise on the in-place float multiply, as before).
        max_intensity: Upper bound of the per-channel scaling factor (default 1.5).

    Returns:
        A new array with the same shape and dtype as `x`, clipped to [0, 255].
    """
    # Work on a copy so the caller's array is not altered as a side effect.
    x = np.array(x, copy=True)
    # Cap at the real channel count so single-channel images do not raise IndexError.
    for channel in range(min(3, x.shape[2])):
        intensity = np.random.uniform(0, max_intensity)
        x[:, :, channel] *= intensity
    return np.clip(x, 0, 255)

In [None]:
def get_dirnames(path, limit=91):
    """Return the sorted entry names of `path` as a NumPy array, excluding macOS `.DS_Store`.

    Args:
        path: Directory whose entries are listed with os.listdir.
        limit: Keep only the first `limit` names after sorting. The default of 91 preserves
            the original subset; pass None to keep every entry.

    Returns:
        1-D NumPy array of entry names.
    """
    # Explicit filter instead of a bare `except` around list.remove.
    dirnames = sorted(name for name in os.listdir(path) if name != '.DS_Store')
    # TODO: use all classes (limit=None) once the full dataset should be processed.
    return np.array(dirnames[:limit])

def get_filenames(path, dirname):
    """List the entries of sub-directory `dirname` under `path`, excluding macOS `.DS_Store`.

    Args:
        path: Parent directory, with or without a trailing separator.
        dirname: Name of the sub-directory to list.

    Returns:
        List of entry names in os.listdir order (not sorted).
    """
    # os.path.join works whether or not `path` ends with a separator; plain string
    # concatenation produced a non-existent path when the separator was missing.
    files = os.listdir(os.path.join(path, dirname))
    return [name for name in files if name != '.DS_Store']

def divide_files_into_classes(class_types, path):
    """Group the files under `path` into classes built from selected words of each directory name.

    Each directory name is split on single spaces. The words at the positions
    ``class_type_to_index[t]`` (for each `t` in `class_types`, in order) are joined with a
    space to form the class name. Directories that map to the same class name are merged.

    Args:
        class_types: Keys of `class_type_to_index` (e.g. ['make', 'model']) selecting
            which words form the class name.
        path: Dataset root; directories and files are listed via get_dirnames/get_filenames.

    Returns:
        dict mapping class name -> NumPy string array of shape (n, 2), one (dirname, filename)
        row per file. Empty directories are printed and skipped, so no array is empty.
    """
    class_indices = [class_type_to_index[t] for t in class_types]
    grouped = {}
    for dirname in get_dirnames(path):
        if dirname == '.DS_Store':
            continue
        words = dirname.split(' ')
        classname = ' '.join(words[index] for index in class_indices).strip()

        filenames = get_filenames(path, dirname)
        if len(filenames) == 0:
            # Report and skip. An empty (0,) array cannot be concatenated with the (n, 2)
            # arrays of other directories, and a class with zero files breaks augment().
            print(path, dirname)
            continue

        grouped.setdefault(classname, []).extend((dirname, f) for f in filenames)

    # Build each array once instead of re-concatenating per directory.
    return {classname: np.array(tuples) for classname, tuples in grouped.items()}

def augment(data_dir, class_types, generator, factor):
    """Oversample every class under `data_dir` up to `factor` times the size of the largest class.

    Classes come from divide_files_into_classes. Each class is topped up by
    augment_single_class until it holds ``factor * len(largest class)`` files.

    Args:
        data_dir: Dataset root containing one sub-directory per dataset folder.
        class_types: Keys selecting which directory-name words define a class
            (e.g. ['make', 'model']).
        generator: Keras ImageDataGenerator used for the random transform.
        factor: Multiplier applied to the largest class size to get the per-class target.
    """
    classes_to_files = divide_files_into_classes(class_types, data_dir)
    if not classes_to_files:
        print('No classes found in', data_dir)
        return

    # max() returns the first maximal key, matching the previous strict ">" scan.
    biggest_class = max(classes_to_files, key=lambda c: len(classes_to_files[c]))
    augmented_num_files = factor * len(classes_to_files[biggest_class])
    print('Augmented number of files in each class:', augmented_num_files)
    print('Biggest class:', biggest_class)

    for count, (c, filename_tuples) in enumerate(classes_to_files.items()):
        if len(filename_tuples) == 0:
            # Nothing to sample from; also avoids a ZeroDivisionError below.
            print('Class ' + str(count) + ' ' + c + ' skipped (no files).')
            continue
        class_factor = math.ceil(augmented_num_files / len(filename_tuples))
        augment_single_class(data_dir, filename_tuples, generator, class_factor, augmented_num_files)
        print('Class ' + str(count) + ' ' + c + ' augmented.')


In [None]:
# Seed NumPy's global RNG so that re-running the notebook reproduces the same
# color augmentations (mess_up_colors / swith_colors draw from np.random). Whether
# generator.random_transform also becomes deterministic depends on the Keras version
# drawing from the global NumPy RNG — TODO confirm for the installed version.
SEED = 42
np.random.seed(SEED)

generator = ImageDataGenerator(
    horizontal_flip=True,       # random left/right mirror
    rotation_range=15,          # random rotation, degree range
    shear_range=15,             # shear angle, degrees (Keras convention)
    brightness_range=(0.5, 2),  # range the brightness factor is drawn from
    fill_mode='constant',       # pixels uncovered by rotation/shear get a constant value...
    cval=0,                     # ...namely 0 (black)
)

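## Preview

Apply the random geometric transform plus one of the two color augmentations to a single training image, to check the augmentation strength before any files are written.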

In [None]:
# Show the original image plus 20 random augmentations in a single grid,
# instead of 21 separate full-size figures.
# NOTE(review): this path contains 'datasets/' but the augmentation cell uses
# './stanford-car-dataset-by-classes-folder/train_dataset/' — confirm which layout is correct.
img = load_img('stanford-car-dataset-by-classes-folder/datasets/train_dataset/Acura RL Sedan 2012/00249.jpg')
x = img_to_array(img)

n_rows, n_cols = 7, 3  # 1 original + 20 augmented samples
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 30))
axes = axes.ravel()
axes[0].imshow(img)
axes[0].set_title('original')

for i, ax in enumerate(axes[1:], start=1):
    x_transformed = generator.random_transform(x)
    should_mess_up_colors = np.random.choice([0, 1])
    if should_mess_up_colors == 1:
        x_transformed = mess_up_colors(x_transformed)
        ax.set_title('#' + str(i) + ': channel intensities scaled')
    else:
        x_transformed = swith_colors(x_transformed)
        ax.set_title('#' + str(i) + ': channels permuted')
    img_transformed = array_to_img(x_transformed)
    ax.imshow(img_transformed)

for ax in axes:
    ax.axis('off')
fig.tight_layout()
plt.show()

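## Augmentation

Top up every make + model class with augmented copies until each class holds `factor` × the size of the largest class. The new images are saved as `_<pass>_<filename>` next to the originals. Re-running this cell therefore also treats previously generated files as source images.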

In [None]:
# Oversample each make+model class up to twice the size of the largest class.
# Augmented images are written into the dataset directories themselves.
augment(
    data_dir='./stanford-car-dataset-by-classes-folder/train_dataset/',
    class_types=['make', 'model'],
    generator=generator,
    factor=2,
)